初始化项目,由ModelHub XC社区提供模型
Model: vumichien/wav2vec2-large-xlsr-japanese Source: Original Platform
This commit is contained in:
17
.gitattributes
vendored
Normal file
17
.gitattributes
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
521
Fine-Tune-Wav2Vec2-Large-XLSR-Japan.ipynb
Normal file
521
Fine-Tune-Wav2Vec2-Large-XLSR-Japan.ipynb
Normal file
@@ -0,0 +1,521 @@
|
||||
{
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": 3
|
||||
},
|
||||
"orig_nbformat": 2
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2,
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%capture\n",
|
||||
"!pip install datasets==1.4.1\n",
|
||||
"!pip install transformers==4.4.0\n",
|
||||
"!pip install torchaudio\n",
|
||||
"!pip install librosa\n",
|
||||
"!pip install jiwer\n",
|
||||
"!pip install mecab-python3\n",
|
||||
"!pip install unidic-lite\n",
|
||||
"!pip isntall audiomentations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor\n",
|
||||
"from datasets import load_dataset, load_metric, ClassLabel, Dataset\n",
|
||||
"from audiomentations import Compose, AddGaussianNoise, Gain, PitchShift, TimeStretch, Shift\n",
|
||||
"from torch.optim.lr_scheduler import LambdaLR\n",
|
||||
"from transformers import Wav2Vec2ForCTC, TrainingArguments, Trainer\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import soundfile as sf\n",
|
||||
"import re\n",
|
||||
"import json\n",
|
||||
"import torchaudio\n",
|
||||
"import librosa\n",
|
||||
"import datasets\n",
|
||||
"import MeCab\n",
|
||||
"import pykakasi\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from dataclasses import dataclass, field\n",
|
||||
"from typing import Any, Dict, List, Optional, Union"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"# Load dataset and prepare processor"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load public dataset from University of Tokyo\n",
|
||||
"!wget http://ss-takashi.sakura.ne.jp/corpus/jsut_ver1.1.zip\n",
|
||||
"!unzip jsut_ver1.1.zip\n",
|
||||
"\n",
|
||||
"path = 'jsut_ver1.1/basic5000/'\n",
|
||||
"df = pd.read_csv(path + 'transcript_utf8.txt', header = None, delimiter = \":\", names=[\"path\", \"sentence\"], index_col=False)\n",
|
||||
"df[\"path\"] = df[\"path\"].map(lambda x: path + 'wav/' + x + \".wav\")\n",
|
||||
"df.head()\n",
|
||||
"\n",
|
||||
"jsut_voice_train = Dataset.from_pandas(df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import training dataset\n",
|
||||
"common_voice_train = load_dataset('common_voice', 'ja',split='train+validation')\n",
|
||||
"common_voice_test = load_dataset('common_voice', 'ja', split='test')\n",
|
||||
"\n",
|
||||
"# Remove unwanted columns\n",
|
||||
"common_voice_train = common_voice_train.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
|
||||
"common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
|
||||
"\n",
|
||||
"# Concat common voice and public dataset\n",
|
||||
"common_voice_train = datasets.concatenate_datasets([jsut_voice_train, common_voice_train])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parser Japanese sentence. Ex: \"pythonが大好きです\" -> \"python が 大好き です EOS\"\n",
|
||||
"wakati = MeCab.Tagger(\"-Owakati\")\n",
|
||||
"\n",
|
||||
"# Unwanted token\n",
|
||||
"chars_to_ignore_regex = '[\\,\\、\\。\\.\\「\\」\\…\\?\\・]'\n",
|
||||
"\n",
|
||||
"def remove_special_characters(batch):\n",
|
||||
" batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n",
|
||||
" batch[\"sentence\"] = re.sub(chars_to_ignore_regex,'', batch[\"sentence\"]).strip()\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"common_voice_train = common_voice_train.map(remove_special_characters)\n",
|
||||
"common_voice_test = common_voice_test.map(remove_special_characters)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# make vocab file\n",
|
||||
"def extract_all_chars(batch):\n",
|
||||
" all_text = \" \".join(batch[\"sentence\"])\n",
|
||||
" vocab = list(set(all_text))\n",
|
||||
" return {\"vocab\": [vocab], \"all_text\": [all_text]}\n",
|
||||
"\n",
|
||||
"# make vocab list and text\n",
|
||||
"vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)\n",
|
||||
"vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)\n",
|
||||
"\n",
|
||||
"# concate vocab from train and test set\n",
|
||||
"vocab_list = list(set(vocab_train[\"vocab\"][0]) | set(vocab_test[\"vocab\"][0]))\n",
|
||||
"vocab_dict = {v: k for k, v in enumerate(vocab_list)}\n",
|
||||
"print(len(vocab_dict))\n",
|
||||
"vocab_dict[\"|\"] = vocab_dict[\" \"]\n",
|
||||
"del vocab_dict[\" \"]\n",
|
||||
"\n",
|
||||
"# create unk and pad token\n",
|
||||
"vocab_dict[\"[UNK]\"] = len(vocab_dict)\n",
|
||||
"vocab_dict[\"[PAD]\"] = len(vocab_dict)\n",
|
||||
"\n",
|
||||
"# save to json file\n",
|
||||
"with open('vocab.json', 'w') as vocab_file:\n",
|
||||
" json.dump(vocab_dict, vocab_file, indent=2, ensure_ascii=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"save_dir = \"./output_models\"\n",
|
||||
"# wrap tokenizer and feature extractor to processor\n",
|
||||
"tokenizer = Wav2Vec2CTCTokenizer(\"./vocab_demo.json\", unk_token=\"[UNK]\", pad_token=\"[PAD]\", word_delimiter_token=\"|\")\n",
|
||||
"feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)\n",
|
||||
"processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)\n",
|
||||
"processor.save_pretrained(save_dir)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"# Prepare train and test dataset "
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# convert audio from 48kHz to 16kHz (standard sample rate of wave2vec model)\n",
|
||||
"def speech_file_to_array_fn(batch):\n",
|
||||
" speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
|
||||
" batch[\"speech\"] = librosa.resample(np.asarray(speech_array[0].numpy()), 48_000, 16_000)\n",
|
||||
" batch[\"sampling_rate\"] = 16_000\n",
|
||||
" batch[\"target_text\"] = batch[\"sentence\"]\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names,num_proc=4)\n",
|
||||
"common_voice_test = common_voice_test.map(speech_file_to_array_fn,remove_columns=common_voice_test.column_names, num_proc=4) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# do augment to enrich common voice dataset \n",
|
||||
"augment = Compose([\n",
|
||||
" AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.8),\n",
|
||||
" PitchShift(min_semitones=-1, max_semitones=1, p=0.8),\n",
|
||||
" Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8),\n",
|
||||
" TimeStretch(min_rate=0.8, max_rate=1.25, p=0.8)\n",
|
||||
"\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"def augmented_speech(batch, augment):\n",
|
||||
" samples = np.array(batch[\"speech\"])\n",
|
||||
" batch[\"speech\"] = augment(samples=samples, sample_rate=16000)\n",
|
||||
" batch[\"sampling_rate\"] = 16_000\n",
|
||||
" batch[\"target_text\"] = batch[\"target_text\"]\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"# augument 50% of trainset\n",
|
||||
"common_voice_train_augmented = common_voice_train.train_test_split(test_size = 0.5)['train']\n",
|
||||
"common_voice_train_augmented = common_voice_train_augmented.map(lambda batch: augmented_speech(batch, augment), num_proc=4)\n",
|
||||
"\n",
|
||||
"# concate with trainset\n",
|
||||
"common_voice_train = datasets.concatenate_datasets([common_voice_train_augmented, common_voice_train])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prepare_dataset(batch):\n",
|
||||
" # check that all files have the correct sampling rate\n",
|
||||
" assert (\n",
|
||||
" len(set(batch[\"sampling_rate\"])) == 1\n",
|
||||
" ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n",
|
||||
"\n",
|
||||
" batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n",
|
||||
" \n",
|
||||
" with processor.as_target_processor():\n",
|
||||
" batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
|
||||
" return batch\n",
|
||||
" \n",
|
||||
"# prepare dataset\n",
|
||||
"common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=4, batched=True)\n",
|
||||
"common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"# Training"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create data collator\n",
|
||||
"@dataclass\n",
|
||||
"class DataCollatorCTCWithPadding:\n",
|
||||
"\n",
|
||||
" processor: Wav2Vec2Processor\n",
|
||||
" padding: Union[bool, str] = True\n",
|
||||
" max_length: Optional[int] = None\n",
|
||||
" max_length_labels: Optional[int] = None\n",
|
||||
" pad_to_multiple_of: Optional[int] = None\n",
|
||||
" pad_to_multiple_of_labels: Optional[int] = None\n",
|
||||
"\n",
|
||||
" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
|
||||
" input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n",
|
||||
" label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
|
||||
"\n",
|
||||
" batch = self.processor.pad(\n",
|
||||
" input_features,\n",
|
||||
" padding=self.padding,\n",
|
||||
" max_length=self.max_length,\n",
|
||||
" pad_to_multiple_of=self.pad_to_multiple_of,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" )\n",
|
||||
" with self.processor.as_target_processor():\n",
|
||||
" labels_batch = self.processor.pad(\n",
|
||||
" label_features,\n",
|
||||
" padding=self.padding,\n",
|
||||
" max_length=self.max_length_labels,\n",
|
||||
" pad_to_multiple_of=self.pad_to_multiple_of_labels,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # replace padding with -100 to ignore loss correctly\n",
|
||||
" labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
|
||||
"\n",
|
||||
" batch[\"labels\"] = labels\n",
|
||||
"\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# make metric function\n",
|
||||
"wer_metric = load_metric(\"wer\")\n",
|
||||
"\n",
|
||||
"def compute_metrics(pred):\n",
|
||||
" pred_logits = pred.predictions\n",
|
||||
" pred_ids = np.argmax(pred_logits, axis=-1)\n",
|
||||
"\n",
|
||||
" pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n",
|
||||
"\n",
|
||||
" pred_str = processor.batch_decode(pred_ids)\n",
|
||||
" # we do not want to group tokens when computing the metrics\n",
|
||||
" label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n",
|
||||
"\n",
|
||||
" wer = wer_metric.compute(predictions=pred_str, references=label_str)\n",
|
||||
"\n",
|
||||
" return {\"wer\": wer}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create custom learning scheduler\n",
|
||||
"\n",
|
||||
"# polynomial decay\n",
|
||||
"def get_polynomial_decay_schedule_with_warmup(\n",
|
||||
" optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.2, last_epoch=-1\n",
|
||||
"):\n",
|
||||
"\n",
|
||||
" lr_init = optimizer.defaults[\"lr\"]\n",
|
||||
" assert lr_init > lr_end, f\"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})\"\n",
|
||||
"\n",
|
||||
" def lr_lambda(current_step: int):\n",
|
||||
" if current_step < num_warmup_steps:\n",
|
||||
" return float(current_step) / float(max(1, num_warmup_steps))\n",
|
||||
" elif current_step > num_training_steps:\n",
|
||||
" return lr_end / lr_init # as LambdaLR multiplies by lr_init\n",
|
||||
" else:\n",
|
||||
" lr_range = lr_init - lr_end\n",
|
||||
" decay_steps = num_training_steps - num_warmup_steps\n",
|
||||
" pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps\n",
|
||||
" decay = lr_range * pct_remaining ** power + lr_end\n",
|
||||
" return decay / lr_init # as LambdaLR multiplies by lr_init\n",
|
||||
"\n",
|
||||
" return LambdaLR(optimizer, lr_lambda, last_epoch)\n",
|
||||
" \n",
|
||||
"# wrap custom learning scheduler with trainer\n",
|
||||
"class PolyTrainer(Trainer):\n",
|
||||
" def __init__(self, *args, **kwargs):\n",
|
||||
" super().__init__(*args, **kwargs)\n",
|
||||
" \n",
|
||||
" def create_scheduler(self, num_training_steps: int):\n",
|
||||
" self.lr_scheduler = get_polynomial_decay_schedule_with_warmup(self.optimizer, \n",
|
||||
" num_warmup_steps=self.args.warmup_steps,\n",
|
||||
" num_training_steps=num_training_steps)\n",
|
||||
" def create_optimizer_and_scheduler(self, num_training_steps: int):\n",
|
||||
" self.create_optimizer()\n",
|
||||
" self.create_scheduler(num_training_steps)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load pretrain model\n",
|
||||
"model = Wav2Vec2ForCTC.from_pretrained(\n",
|
||||
" \"facebook/wav2vec2-large-xlsr-53\", \n",
|
||||
" attention_dropout=0.1,\n",
|
||||
" hidden_dropout=0.1,\n",
|
||||
" feat_proj_dropout=0.1,\n",
|
||||
" mask_time_prob=0.1, \n",
|
||||
" layerdrop=0.1,\n",
|
||||
" gradient_checkpointing=True, \n",
|
||||
" ctc_loss_reduction=\"mean\", \n",
|
||||
" pad_token_id=processor.tokenizer.pad_token_id,\n",
|
||||
" vocab_size=len(processor.tokenizer)\n",
|
||||
")\n",
|
||||
"# free feature extractor\n",
|
||||
"model.freeze_feature_extractor()\n",
|
||||
"\n",
|
||||
"# define train argument\n",
|
||||
"training_args = TrainingArguments(\n",
|
||||
" output_dir=save_dir,\n",
|
||||
" group_by_length=True,\n",
|
||||
" per_device_train_batch_size=32,\n",
|
||||
" gradient_accumulation_steps=2,\n",
|
||||
" evaluation_strategy=\"steps\",\n",
|
||||
" num_train_epochs=200,\n",
|
||||
" fp16=True,\n",
|
||||
" save_steps=2400, \n",
|
||||
" eval_steps=800,\n",
|
||||
" logging_steps=800, \n",
|
||||
" learning_rate=1e-4, \n",
|
||||
" warmup_steps=1500, \n",
|
||||
" save_total_limit=2,\n",
|
||||
" load_best_model_at_end = True, \n",
|
||||
" metric_for_best_model='wer', \n",
|
||||
" greater_is_better=False\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# wrap everything to Trainer\n",
|
||||
"trainer = PolyTrainer(\n",
|
||||
" model=model,\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" args=training_args,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
" train_dataset=common_voice_train,\n",
|
||||
" eval_dataset=common_voice_test,\n",
|
||||
" tokenizer=processor.feature_extractor,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# training\n",
|
||||
"train_result = trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"# Testing result"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torchaudio\n",
|
||||
"from datasets import load_dataset, load_metric\n",
|
||||
"from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n",
|
||||
"import MeCab\n",
|
||||
"import pykakasi\n",
|
||||
"import re\n",
|
||||
"\n",
|
||||
"#config\n",
|
||||
"wakati = MeCab.Tagger(\"-Owakati\")\n",
|
||||
"chars_to_ignore_regex = '[\\,\\、\\。\\.\\「\\」\\…\\?\\・]'\n",
|
||||
"\n",
|
||||
"#load model\n",
|
||||
"processor = Wav2Vec2Processor.from_pretrained(save_dir)\n",
|
||||
"test_model = Wav2Vec2ForCTC.from_pretrained(save_dir)\n",
|
||||
"test_model.to(\"cuda\")\n",
|
||||
"resampler = torchaudio.transforms.Resample(48_000, 16_000)\n",
|
||||
"\n",
|
||||
"#load testdata\n",
|
||||
"test_dataset = load_dataset(\"common_voice\", \"ja\", split=\"test\")\n",
|
||||
"wer = load_metric(\"wer\")\n",
|
||||
"\n",
|
||||
"# Preprocessing the datasets.\n",
|
||||
"def speech_file_to_array_fn(batch):\n",
|
||||
" batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n",
|
||||
" batch[\"sentence\"] = re.sub(chars_to_ignore_regex,'', batch[\"sentence\"]).strip()\n",
|
||||
" speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
|
||||
" batch[\"speech\"] = resampler(speech_array).squeeze().numpy()\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"test_dataset = test_dataset.map(speech_file_to_array_fn)\n",
|
||||
"\n",
|
||||
"# Preprocessing the datasets.\n",
|
||||
"# We need to read the aduio files as arrays\n",
|
||||
"def evaluate(batch):\n",
|
||||
" inputs = processor(batch[\"speech\"], sampling_rate=16_000, return_tensors=\"pt\", padding=True)\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" logits = test_model(inputs.input_values.to(\"cuda\"), attention_mask=inputs.attention_mask.to(\"cuda\")).logits\n",
|
||||
" pred_ids = torch.argmax(logits, dim=-1)\n",
|
||||
" batch[\"pred_strings\"] = processor.batch_decode(pred_ids)\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"result = test_dataset.map(evaluate, batched=True, batch_size=8)\n",
|
||||
"\n",
|
||||
"print(\"WER: {:2f}\".format(100 * wer.compute(predictions=result[\"pred_strings\"], references=result[\"sentence\"])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# print some reusults\n",
|
||||
"pick = random.randint(0, len(common_voice_test_transcription)-1)\n",
|
||||
"input_dict = processor(common_voice_test[\"input_values\"][pick], return_tensors=\"pt\", padding=True)\n",
|
||||
"logits = test_model(input_dict.input_values.to(\"cuda\")).logits\n",
|
||||
"pred_ids = torch.argmax(logits, dim=-1)[0]\n",
|
||||
"\n",
|
||||
"print(\"Prediction:\")\n",
|
||||
"print(processor.decode(pred_ids).strip())\n",
|
||||
"\n",
|
||||
"print(\"\\nLabel:\")\n",
|
||||
"print(processor.decode(common_voice_test['labels'][pick]))\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
131
README.md
Normal file
131
README.md
Normal file
@@ -0,0 +1,131 @@
|
||||
---
|
||||
language: ja
|
||||
datasets:
|
||||
- common_voice
|
||||
metrics:
|
||||
- wer
|
||||
tags:
|
||||
- audio
|
||||
- automatic-speech-recognition
|
||||
- speech
|
||||
- xlsr-fine-tuning-week
|
||||
license: apache-2.0
|
||||
model-index:
|
||||
- name: XLSR Wav2Vec2 Japanese by Chien Vu
|
||||
results:
|
||||
- task:
|
||||
name: Speech Recognition
|
||||
type: automatic-speech-recognition
|
||||
dataset:
|
||||
name: Common Voice Japanese
|
||||
type: common_voice
|
||||
args: ja
|
||||
metrics:
|
||||
- name: Test WER
|
||||
type: wer
|
||||
value: 30.84
|
||||
- name: Test CER
|
||||
type: cer
|
||||
value: 17.85
|
||||
widget:
|
||||
- example_title: Japanese speech corpus sample 1
|
||||
src: https://u.pcloud.link/publink/show?code=XZwhAlXZFOtXiqKHMzmYS9wXrCP8Yb7EtRd7
|
||||
- example_title: Japanese speech corpus sample 2
|
||||
src: https://u.pcloud.link/publink/show?code=XZ6hAlXZ5ccULt0YtrhJFl7LygKg0SJzKX0k
|
||||
---
|
||||
# Wav2Vec2-Large-XLSR-53-Japanese
|
||||
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Japanese using the [Common Voice](https://huggingface.co/datasets/common_voice) and Japanese speech corpus of Saruwatari-lab, University of Tokyo [JSUT](https://sites.google.com/site/shinnosuketakamichi/publication/jsut).
|
||||
When using this model, make sure that your speech input is sampled at 16kHz.
|
||||
## Usage
|
||||
The model can be used directly (without a language model) as follows:
|
||||
```python
|
||||
!pip install mecab-python3
|
||||
!pip install unidic-lite
|
||||
!python -m unidic download
|
||||
import torch
|
||||
import torchaudio
|
||||
import librosa
|
||||
from datasets import load_dataset
|
||||
import MeCab
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
import re
|
||||
|
||||
# config
|
||||
wakati = MeCab.Tagger("-Owakati")
|
||||
chars_to_ignore_regex = '[\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\,\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\、\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\。\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\.\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\「\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\」\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\…\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\?\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\・]'
|
||||
|
||||
# load data, processor and model
|
||||
test_dataset = load_dataset("common_voice", "ja", split="test[:2%]")
|
||||
processor = Wav2Vec2Processor.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese")
|
||||
resampler = lambda sr, y: librosa.resample(y.numpy().squeeze(), sr, 16_000)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
def speech_file_to_array_fn(batch):
|
||||
batch["sentence"] = wakati.parse(batch["sentence"]).strip()
|
||||
batch["sentence"] = re.sub(chars_to_ignore_regex,'', batch["sentence"]).strip()
|
||||
speech_array, sampling_rate = torchaudio.load(batch["path"])
|
||||
batch["speech"] = resampler(sampling_rate, speech_array).squeeze()
|
||||
return batch
|
||||
test_dataset = test_dataset.map(speech_file_to_array_fn)
|
||||
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
|
||||
with torch.no_grad():
|
||||
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
print("Prediction:", processor.batch_decode(predicted_ids))
|
||||
print("Reference:", test_dataset["sentence"][:2])
|
||||
```
|
||||
## Evaluation
|
||||
The model can be evaluated as follows on the Japanese test data of Common Voice.
|
||||
```python
|
||||
!pip install mecab-python3
|
||||
!pip install unidic-lite
|
||||
!python -m unidic download
|
||||
|
||||
import torch
|
||||
import librosa
|
||||
import torchaudio
|
||||
from datasets import load_dataset, load_metric
|
||||
import MeCab
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
import re
|
||||
|
||||
#config
|
||||
wakati = MeCab.Tagger("-Owakati")
|
||||
chars_to_ignore_regex = '[\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\,\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\、\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\。\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\.\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\「\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\」\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\…\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\?\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\・]'
|
||||
|
||||
# load data, processor and model
|
||||
test_dataset = load_dataset("common_voice", "ja", split="test")
|
||||
wer = load_metric("wer")
|
||||
processor = Wav2Vec2Processor.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese")
|
||||
model.to("cuda")
|
||||
resampler = lambda sr, y: librosa.resample(y.numpy().squeeze(), sr, 16_000)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
def speech_file_to_array_fn(batch):
|
||||
batch["sentence"] = wakati.parse(batch["sentence"]).strip()
|
||||
batch["sentence"] = re.sub(chars_to_ignore_regex,'', batch["sentence"]).strip()
|
||||
speech_array, sampling_rate = torchaudio.load(batch["path"])
|
||||
batch["speech"] = resampler(sampling_rate, speech_array).squeeze()
|
||||
return batch
|
||||
test_dataset = test_dataset.map(speech_file_to_array_fn)
|
||||
|
||||
# evaluate function
|
||||
def evaluate(batch):
|
||||
inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
|
||||
with torch.no_grad():
|
||||
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
|
||||
pred_ids = torch.argmax(logits, dim=-1)
|
||||
batch["pred_strings"] = processor.batch_decode(pred_ids)
|
||||
return batch
|
||||
result = test_dataset.map(evaluate, batched=True, batch_size=8)
|
||||
print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
|
||||
```
|
||||
## Test Result
|
||||
**WER:** 30.84%,
|
||||
**CER:** 17.85%
|
||||
|
||||
## Training
|
||||
The Common Voice `train`, `validation` datasets and Japanese speech corpus `basic5000` datasets were used for training.
|
||||
|
||||
76
config.json
Normal file
76
config.json
Normal file
@@ -0,0 +1,76 @@
|
||||
{
|
||||
"_name_or_path": "facebook/wav2vec2-large-xlsr-53",
|
||||
"activation_dropout": 0.0,
|
||||
"apply_spec_augment": true,
|
||||
"architectures": [
|
||||
"Wav2Vec2ForCTC"
|
||||
],
|
||||
"attention_dropout": 0.1,
|
||||
"bos_token_id": 1,
|
||||
"conv_bias": true,
|
||||
"conv_dim": [
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"conv_kernel": [
|
||||
10,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"conv_stride": [
|
||||
5,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"ctc_loss_reduction": "mean",
|
||||
"ctc_zero_infinity": false,
|
||||
"do_stable_layer_norm": true,
|
||||
"eos_token_id": 2,
|
||||
"feat_extract_activation": "gelu",
|
||||
"feat_extract_dropout": 0.0,
|
||||
"feat_extract_norm": "layer",
|
||||
"feat_proj_dropout": 0.1,
|
||||
"final_dropout": 0.0,
|
||||
"gradient_checkpointing": true,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout": 0.1,
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"layerdrop": 0.1,
|
||||
"mask_channel_length": 10,
|
||||
"mask_channel_min_space": 1,
|
||||
"mask_channel_other": 0.0,
|
||||
"mask_channel_prob": 0.0,
|
||||
"mask_channel_selection": "static",
|
||||
"mask_feature_length": 10,
|
||||
"mask_feature_prob": 0.0,
|
||||
"mask_time_length": 10,
|
||||
"mask_time_min_space": 1,
|
||||
"mask_time_other": 0.0,
|
||||
"mask_time_prob": 0.1,
|
||||
"mask_time_selection": "static",
|
||||
"model_type": "wav2vec2",
|
||||
"num_attention_heads": 16,
|
||||
"num_conv_pos_embedding_groups": 16,
|
||||
"num_conv_pos_embeddings": 128,
|
||||
"num_feat_extract_layers": 7,
|
||||
"num_hidden_layers": 24,
|
||||
"pad_token_id": 2698,
|
||||
"transformers_version": "4.5.0.dev0",
|
||||
"vocab_size": 2699
|
||||
}
|
||||
3
model.safetensors
Normal file
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c0c568b20eaccdfa8f8bbbcc87ae575473b519413853b0faf97ae88eab02a876
|
||||
size 1272873364
|
||||
8
preprocessor_config.json
Normal file
8
preprocessor_config.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"do_normalize": true,
|
||||
"feature_size": 1,
|
||||
"padding_side": "right",
|
||||
"padding_value": 0.0,
|
||||
"return_attention_mask": true,
|
||||
"sampling_rate": 16000
|
||||
}
|
||||
3
pytorch_model.bin
Normal file
3
pytorch_model.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9f20aa1dd1aea1e481a5e87034fe4733dabfb16d6c44096dfc55ecef8fd0777c
|
||||
size 1272999703
|
||||
1
special_tokens_map.json
Normal file
1
special_tokens_map.json
Normal file
@@ -0,0 +1 @@
|
||||
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]"}
|
||||
1
tokenizer_config.json
Normal file
1
tokenizer_config.json
Normal file
@@ -0,0 +1 @@
|
||||
{"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|"}
|
||||
3
training_args.bin
Normal file
3
training_args.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:540a9cb029ba4c351fdb2719b51eb95f3c1683ee51cb25588127285fcfe59420
|
||||
size 2351
|
||||
1
vocab.json
Normal file
1
vocab.json
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user