commit 4bc6b41f80df32935597919ac53787dcec4b6ebc Author: ModelHub XC Date: Wed May 20 23:00:36 2026 +0800 初始化项目,由ModelHub XC社区提供模型 Model: imvladikon/wav2vec2-xls-r-300m-hebrew Source: Original Platform diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..2bf9985 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,28 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +model.safetensors filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0348ea9 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +checkpoint-*/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..8611593 --- /dev/null +++ b/README.md @@ -0,0 +1,251 @@ +--- +language: +- he +tags: +- automatic-speech-recognition +- generated_from_trainer +- he +- 
hf-asr-leaderboard +- robust-speech-event +base_model: facebook/wav2vec2-xls-r-300m +model-index: +- name: wav2vec2-xls-r-300m-hebrew + results: + - task: + type: automatic-speech-recognition + name: Automatic Speech Recognition + dataset: + name: Custom Dataset + type: custom + args: he + metrics: + - type: wer + value: 23.18 + name: Test WER +--- + + + +# wav2vec2-xls-r-300m-hebrew + +This model is a fine-tuned version of [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) on the private datasets in 2 stages: first it was fine-tuned on a small dataset with good samples. Then the obtained model was fine-tuned on a large dataset with the small good dataset, with various samples from different sources, and with an unlabeled dataset that was weakly labeled using a previously trained model. + +Small dataset: + +| split |size(gb) | n_samples | duration(hrs)| | +|---|---|---|---|---| +|train|4.19| 20306 | 28 | | +|dev |1.05| 5076 | 7 | | + +Large dataset: + +| split |size(gb) | n_samples | duration(hrs)| | +|---|---|---|---|---| +|train|12.3| 90777 | 69 | | +|dev |2.39| 20246 | 14* | | +(*weakly labeled data wasn't used in validation set) + +After the first training it achieves: + +on small dataset +- Loss: 0.5438 +- WER: 0.1773 + +on large dataset +- WER: 0.3811 + +After the second training: +on small dataset +- WER: 0.1697 + +on large dataset +- Loss: 0.4502 +- WER: 0.2318 + +## Model description + +More information needed + +## Intended uses & limitations + +More information needed + +## Training and evaluation data + +More information needed + +## Training procedure + +### Training hyperparameters + + +#### First training + +The following hyperparameters were used during training: +- learning_rate: 0.0003 +- train_batch_size: 8 +- eval_batch_size: 8 +- seed: 42 +- distributed_type: multi-GPU +- num_devices: 2 +- gradient_accumulation_steps: 4 +- total_train_batch_size: 64 +- total_eval_batch_size: 16 +- optimizer: Adam with betas=(0.9,0.999) 
and epsilon=1e-08 +- lr_scheduler_type: linear +- lr_scheduler_warmup_steps: 1000 +- num_epochs: 100.0 +- mixed_precision_training: Native AMP + +Training results + +| Training Loss | Epoch | Step | Validation Loss | Wer | +|:-------------:|:-----:|:-----:|:---------------:|:------:| +| No log | 3.15 | 1000 | 0.5203 | 0.4333 | +| 1.4284 | 6.31 | 2000 | 0.4816 | 0.3951 | +| 1.4284 | 9.46 | 3000 | 0.4315 | 0.3546 | +| 1.283 | 12.62 | 4000 | 0.4278 | 0.3404 | +| 1.283 | 15.77 | 5000 | 0.4090 | 0.3054 | +| 1.1777 | 18.93 | 6000 | 0.3893 | 0.3006 | +| 1.1777 | 22.08 | 7000 | 0.3968 | 0.2857 | +| 1.0994 | 25.24 | 8000 | 0.3892 | 0.2751 | +| 1.0994 | 28.39 | 9000 | 0.4061 | 0.2690 | +| 1.0323 | 31.54 | 10000 | 0.4114 | 0.2507 | +| 1.0323 | 34.7 | 11000 | 0.4021 | 0.2508 | +| 0.9623 | 37.85 | 12000 | 0.4032 | 0.2378 | +| 0.9623 | 41.01 | 13000 | 0.4148 | 0.2374 | +| 0.9077 | 44.16 | 14000 | 0.4350 | 0.2323 | +| 0.9077 | 47.32 | 15000 | 0.4515 | 0.2246 | +| 0.8573 | 50.47 | 16000 | 0.4474 | 0.2180 | +| 0.8573 | 53.63 | 17000 | 0.4649 | 0.2171 | +| 0.8083 | 56.78 | 18000 | 0.4455 | 0.2102 | +| 0.8083 | 59.94 | 19000 | 0.4587 | 0.2092 | +| 0.769 | 63.09 | 20000 | 0.4794 | 0.2012 | +| 0.769 | 66.25 | 21000 | 0.4845 | 0.2007 | +| 0.7308 | 69.4 | 22000 | 0.4937 | 0.2008 | +| 0.7308 | 72.55 | 23000 | 0.4920 | 0.1895 | +| 0.6927 | 75.71 | 24000 | 0.5179 | 0.1911 | +| 0.6927 | 78.86 | 25000 | 0.5202 | 0.1877 | +| 0.6622 | 82.02 | 26000 | 0.5266 | 0.1840 | +| 0.6622 | 85.17 | 27000 | 0.5351 | 0.1854 | +| 0.6315 | 88.33 | 28000 | 0.5373 | 0.1811 | +| 0.6315 | 91.48 | 29000 | 0.5331 | 0.1792 | +| 0.6075 | 94.64 | 30000 | 0.5390 | 0.1779 | +| 0.6075 | 97.79 | 31000 | 0.5459 | 0.1773 | + +#### Second training + +The following hyperparameters were used during training: +- learning_rate: 0.0003 +- train_batch_size: 8 +- eval_batch_size: 8 +- seed: 42 +- distributed_type: multi-GPU +- num_devices: 2 +- gradient_accumulation_steps: 4 +- total_train_batch_size: 64 +- total_eval_batch_size: 
16 +- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08 +- lr_scheduler_type: linear +- lr_scheduler_warmup_steps: 1000 +- num_epochs: 60.0 +- mixed_precision_training: Native AMP + +### Training results + +| Training Loss | Epoch | Step | Validation Loss | Wer | +|:-------------:|:-----:|:-----:|:---------------:|:------:| +| No log | 0.7 | 1000 | 0.5371 | 0.3811 | +| 1.3606 | 1.41 | 2000 | 0.5247 | 0.3902 | +| 1.3606 | 2.12 | 3000 | 0.5126 | 0.3859 | +| 1.3671 | 2.82 | 4000 | 0.5062 | 0.3828 | +| 1.3671 | 3.53 | 5000 | 0.4979 | 0.3672 | +| 1.3421 | 4.23 | 6000 | 0.4906 | 0.3816 | +| 1.3421 | 4.94 | 7000 | 0.4784 | 0.3651 | +| 1.328 | 5.64 | 8000 | 0.4810 | 0.3669 | +| 1.328 | 6.35 | 9000 | 0.4747 | 0.3597 | +| 1.3109 | 7.05 | 10000 | 0.4813 | 0.3808 | +| 1.3109 | 7.76 | 11000 | 0.4631 | 0.3561 | +| 1.2873 | 8.46 | 12000 | 0.4603 | 0.3431 | +| 1.2873 | 9.17 | 13000 | 0.4579 | 0.3533 | +| 1.2661 | 9.87 | 14000 | 0.4471 | 0.3365 | +| 1.2661 | 10.58 | 15000 | 0.4584 | 0.3437 | +| 1.249 | 11.28 | 16000 | 0.4461 | 0.3454 | +| 1.249 | 11.99 | 17000 | 0.4482 | 0.3367 | +| 1.2322 | 12.69 | 18000 | 0.4464 | 0.3335 | +| 1.2322 | 13.4 | 19000 | 0.4427 | 0.3454 | +| 1.22 | 14.1 | 20000 | 0.4440 | 0.3395 | +| 1.22 | 14.81 | 21000 | 0.4459 | 0.3378 | +| 1.2044 | 15.51 | 22000 | 0.4406 | 0.3199 | +| 1.2044 | 16.22 | 23000 | 0.4398 | 0.3155 | +| 1.1913 | 16.92 | 24000 | 0.4237 | 0.3150 | +| 1.1913 | 17.63 | 25000 | 0.4287 | 0.3279 | +| 1.1705 | 18.34 | 26000 | 0.4253 | 0.3103 | +| 1.1705 | 19.04 | 27000 | 0.4234 | 0.3098 | +| 1.1564 | 19.75 | 28000 | 0.4174 | 0.3076 | +| 1.1564 | 20.45 | 29000 | 0.4260 | 0.3160 | +| 1.1461 | 21.16 | 30000 | 0.4235 | 0.3036 | +| 1.1461 | 21.86 | 31000 | 0.4309 | 0.3055 | +| 1.1285 | 22.57 | 32000 | 0.4264 | 0.3006 | +| 1.1285 | 23.27 | 33000 | 0.4201 | 0.2880 | +| 1.1135 | 23.98 | 34000 | 0.4131 | 0.2975 | +| 1.1135 | 24.68 | 35000 | 0.4202 | 0.2849 | +| 1.0968 | 25.39 | 36000 | 0.4105 | 0.2888 | +| 1.0968 | 26.09 | 37000 | 0.4210 | 0.2834 
| +| 1.087 | 26.8 | 38000 | 0.4123 | 0.2843 | +| 1.087 | 27.5 | 39000 | 0.4216 | 0.2803 | +| 1.0707 | 28.21 | 40000 | 0.4161 | 0.2787 | +| 1.0707 | 28.91 | 41000 | 0.4186 | 0.2740 | +| 1.0575 | 29.62 | 42000 | 0.4118 | 0.2845 | +| 1.0575 | 30.32 | 43000 | 0.4243 | 0.2773 | +| 1.0474 | 31.03 | 44000 | 0.4221 | 0.2707 | +| 1.0474 | 31.73 | 45000 | 0.4138 | 0.2700 | +| 1.0333 | 32.44 | 46000 | 0.4102 | 0.2638 | +| 1.0333 | 33.15 | 47000 | 0.4162 | 0.2650 | +| 1.0191 | 33.85 | 48000 | 0.4155 | 0.2636 | +| 1.0191 | 34.56 | 49000 | 0.4129 | 0.2656 | +| 1.0087 | 35.26 | 50000 | 0.4157 | 0.2632 | +| 1.0087 | 35.97 | 51000 | 0.4090 | 0.2654 | +| 0.9901 | 36.67 | 52000 | 0.4183 | 0.2587 | +| 0.9901 | 37.38 | 53000 | 0.4251 | 0.2648 | +| 0.9795 | 38.08 | 54000 | 0.4229 | 0.2555 | +| 0.9795 | 38.79 | 55000 | 0.4176 | 0.2546 | +| 0.9644 | 39.49 | 56000 | 0.4223 | 0.2513 | +| 0.9644 | 40.2 | 57000 | 0.4244 | 0.2530 | +| 0.9534 | 40.9 | 58000 | 0.4175 | 0.2538 | +| 0.9534 | 41.61 | 59000 | 0.4213 | 0.2505 | +| 0.9397 | 42.31 | 60000 | 0.4275 | 0.2565 | +| 0.9397 | 43.02 | 61000 | 0.4315 | 0.2528 | +| 0.9269 | 43.72 | 62000 | 0.4316 | 0.2501 | +| 0.9269 | 44.43 | 63000 | 0.4247 | 0.2471 | +| 0.9175 | 45.13 | 64000 | 0.4376 | 0.2469 | +| 0.9175 | 45.84 | 65000 | 0.4335 | 0.2450 | +| 0.9026 | 46.54 | 66000 | 0.4336 | 0.2452 | +| 0.9026 | 47.25 | 67000 | 0.4400 | 0.2427 | +| 0.8929 | 47.95 | 68000 | 0.4382 | 0.2429 | +| 0.8929 | 48.66 | 69000 | 0.4361 | 0.2415 | +| 0.8786 | 49.37 | 70000 | 0.4413 | 0.2398 | +| 0.8786 | 50.07 | 71000 | 0.4392 | 0.2415 | +| 0.8714 | 50.78 | 72000 | 0.4345 | 0.2406 | +| 0.8714 | 51.48 | 73000 | 0.4475 | 0.2402 | +| 0.8589 | 52.19 | 74000 | 0.4473 | 0.2374 | +| 0.8589 | 52.89 | 75000 | 0.4457 | 0.2357 | +| 0.8493 | 53.6 | 76000 | 0.4462 | 0.2366 | +| 0.8493 | 54.3 | 77000 | 0.4494 | 0.2356 | +| 0.8395 | 55.01 | 78000 | 0.4472 | 0.2352 | +| 0.8395 | 55.71 | 79000 | 0.4490 | 0.2339 | +| 0.8295 | 56.42 | 80000 | 0.4489 | 0.2318 | +| 0.8295 | 57.12 | 81000 | 
0.4469 | 0.2320 | +| 0.8225 | 57.83 | 82000 | 0.4478 | 0.2321 | +| 0.8225 | 58.53 | 83000 | 0.4525 | 0.2326 | +| 0.816 | 59.24 | 84000 | 0.4532 | 0.2316 | +| 0.816 | 59.94 | 85000 | 0.4502 | 0.2318 | + + +### Framework versions + +- Transformers 4.17.0.dev0 +- Pytorch 1.10.2+cu102 +- Datasets 1.18.2.dev0 +- Tokenizers 0.11.0 diff --git a/added_tokens.json b/added_tokens.json new file mode 100644 index 0000000..2dc407e --- /dev/null +++ b/added_tokens.json @@ -0,0 +1 @@ +{"": 30, "": 31} \ No newline at end of file diff --git a/all_results.json b/all_results.json new file mode 100644 index 0000000..acb8f5e --- /dev/null +++ b/all_results.json @@ -0,0 +1,14 @@ +{ + "epoch": 60.0, + "eval_loss": 0.4502493739128113, + "eval_runtime": 418.5824, + "eval_samples": 20246, + "eval_samples_per_second": 48.368, + "eval_steps_per_second": 3.024, + "eval_wer": 0.23182961772311134, + "train_loss": 1.0617500136590194, + "train_runtime": 191647.3893, + "train_samples": 90777, + "train_samples_per_second": 28.42, + "train_steps_per_second": 0.444 +} \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..df05220 --- /dev/null +++ b/config.json @@ -0,0 +1,107 @@ +{ + "_name_or_path": "imvladikon/wav2vec2-xls-r-300m-hebrew", + "activation_dropout": 0.1, + "adapter_kernel_size": 3, + "adapter_stride": 2, + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForCTC" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "classifier_proj_size": 256, + "codevector_dim": 768, + "contrastive_logits_temperature": 0.1, + "conv_bias": true, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "mean", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + 
"feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.0, + "feat_quantizer_dropout": 0.0, + "final_dropout": 0.0, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.0, + "mask_feature_length": 64, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.25, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_prob": 0.75, + "model_type": "wav2vec2", + "num_adapter_layers": 3, + "num_attention_heads": 16, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "num_negatives": 100, + "output_hidden_size": 1024, + "pad_token_id": 29, + "proj_codevector_dim": 768, + "tdnn_dilation": [ + 1, + 2, + 3, + 1, + 1 + ], + "tdnn_dim": [ + 512, + 512, + 512, + 512, + 1500 + ], + "tdnn_kernel": [ + 5, + 3, + 3, + 1, + 1 + ], + "torch_dtype": "float32", + "transformers_version": "4.17.0.dev0", + "use_weighted_layer_sum": false, + "vocab_size": 32, + "xvector_output_dim": 512 +} diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..88bfe8f --- /dev/null +++ b/eval.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +import argparse +import re +from typing import Dict + +import torch +from datasets import Audio, Dataset, load_dataset, load_metric + +from transformers import AutoFeatureExtractor, pipeline + + +def log_results(result: Dataset, args: Dict[str, str]): + """DO NOT CHANGE. 
This function computes and logs the result metrics.""" + + log_outputs = args.log_outputs + dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split]) + + # load metric + wer = load_metric("wer") + cer = load_metric("cer") + + # compute metrics + wer_result = wer.compute(references=result["target"], predictions=result["prediction"]) + cer_result = cer.compute(references=result["target"], predictions=result["prediction"]) + + # print & log results + result_str = f"WER: {wer_result}\n" f"CER: {cer_result}" + print(result_str) + + with open(f"{dataset_id}_eval_results.txt", "w") as f: + f.write(result_str) + + # log all results in text file. Possibly interesting for analysis + if log_outputs is not None: + pred_file = f"log_{dataset_id}_predictions.txt" + target_file = f"log_{dataset_id}_targets.txt" + + with open(pred_file, "w") as p, open(target_file, "w") as t: + + # mapping function to write output + def write_to_file(batch, i): + p.write(f"{i}" + "\n") + p.write(batch["prediction"] + "\n") + t.write(f"{i}" + "\n") + t.write(batch["target"] + "\n") + + result.map(write_to_file, with_indices=True) + +def remove_niqqud(string: str) -> str: + return ''.join('' if 1456 <= ord(c) <= 1479 else c for c in string) + +def normalize_text(text: str) -> str: + """DO ADAPT FOR YOUR USE CASE. this function normalizes the target text.""" + + chars_to_ignore_regex = '[,?.!\-\;\:"“%‘”�—’…–]' # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training + text = re.sub(chars_to_ignore_regex, "", text.lower()) + text = remove_niqqud(text) + + # In addition, we can normalize the target text, e.g. removing new lines characters etc... + # note that order is important here! 
+ token_sequences_to_ignore = ["\n\n", "\n", " ", " "] + + for t in token_sequences_to_ignore: + text = " ".join(text.split(t)) + + return text + + +def main(args): + # load dataset + dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True) + + # for testing: only process the first two examples as a test + # dataset = dataset.select(range(10)) + + # load processor + feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id) + sampling_rate = feature_extractor.sampling_rate + + # resample audio + dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate)) + + # load eval pipeline + if args.device is None: + args.device = 0 if torch.cuda.is_available() else -1 + asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device) + + # map function to decode audio + def map_to_pred(batch): + prediction = asr( + batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s + ) + + batch["prediction"] = prediction["text"] + batch["target"] = normalize_text(batch["sentence"]) + return batch + + # run inference on all examples + result = dataset.map(map_to_pred, remove_columns=dataset.column_names) + + # compute and log_results + # do not change function below + log_results(result, args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers" + ) + parser.add_argument( + "--dataset", + type=str, + required=True, + help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets", + ) + parser.add_argument( + "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice" + ) + parser.add_argument("--split", type=str, required=True, help="Split of the dataset. 
*E.g.* `'test'`") + parser.add_argument( + "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to 5 seconds." + ) + parser.add_argument( + "--stride_length_s", type=float, default=None, help="Stride of the audio chunks. Defaults to 1 second." + ) + parser.add_argument( + "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis." + ) + parser.add_argument( + "--device", + type=int, + default=None, + help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.", + ) + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/model.safetensors b/model.safetensors new file mode 100644 index 0000000..b1e155c --- /dev/null +++ b/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75c4e535a62575ac57cf374dec17fb72fbab973dbd69d0391e47918880cf4e00 +size 1261938632 diff --git a/preprocessor_config.json b/preprocessor_config.json new file mode 100644 index 0000000..73caa15 --- /dev/null +++ b/preprocessor_config.json @@ -0,0 +1,9 @@ +{ + "do_normalize": true, + "feature_extractor_type": "Wav2Vec2FeatureExtractor", + "feature_size": 1, + "padding_side": "right", + "padding_value": 0.0, + "return_attention_mask": true, + "sampling_rate": 16000 +} diff --git a/pytorch_model.bin b/pytorch_model.bin new file mode 100644 index 0000000..338d9d7 --- /dev/null +++ b/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f4ee7fbf1fd8b43c413809f06de6b07930c451e471decd5d7219a2704dee64c +size 1262054897 diff --git a/run_train.py b/run_train.py new file mode 100644 index 0000000..07b1260 --- /dev/null +++ b/run_train.py @@ -0,0 +1,981 @@ +# !/usr/bin/env python +# coding=utf-8 +import functools +import json +import logging +import os +import re +import sys +import warnings +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Union + +import 
datasets +import numpy as np +import torch +import torchaudio +from datasets import DatasetDict, ReadInstruction, load_dataset, load_metric, concatenate_datasets + +try: + import bitsandbytes as bnb + + BNB_AVAILABLE = True +except: + BNB_AVAILABLE = False +try: + import wandb + + WANDB_AVAILABLE = True +except: + WANDB_AVAILABLE = False +import transformers +from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoModelForCTC, + AutoTokenizer, + HfArgumentParser, + Trainer, + TrainerCallback, TrainingArguments, + Wav2Vec2Processor, + set_seed, +) + +try: + from torch_audiomentations import ( + Compose, + AddGaussianNoise, + AddGaussianSNR, + ClippingDistortion, + FrequencyMask, + Gain, + LoudnessNormalization, + Normalize, + PitchShift, + PolarityInversion, + Shift, + TimeMask, + TimeStretch, + ) + + AUDIOMENTATIONS_AVAILABLE = True +except: + AUDIOMENTATIONS_AVAILABLE = False +try: + from transformers import AutoProcessor +except: + pass +from transformers.trainer_pt_utils import get_parameter_names +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.16.0") + +require_version( + "datasets>=1.13.3", + "To fix: pip install -r examples/pytorch/text-classification/requirements.txt", +) + +logger = logging.getLogger(__name__) + + +def list_field(default=None, metadata=None): + return field(default_factory=lambda: default, metadata=metadata) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models" + } + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + freeze_feature_encoder: bool = field( + default=True, + metadata={"help": "Whether to freeze the feature encoder layers of the model."}, + ) + attention_dropout: float = field( + default=0.0, + metadata={"help": "The dropout ratio for the attention probabilities."}, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "The dropout ratio for activations inside the fully connected layer." + }, + ) + feat_proj_dropout: float = field( + default=0.0, metadata={"help": "The dropout ratio for the projected features."} + ) + hidden_dropout: float = field( + default=0.0, + metadata={ + "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." + }, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "The dropout probability for the final projection layer."}, + ) + mask_time_prob: float = field( + default=0.05, + metadata={ + "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature" + "vectors will be masked along the time axis." + }, + ) + mask_time_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: float = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. 
Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: int = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) + layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_loss_reduction: Optional[str] = field( + default="mean", + metadata={ + "help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + dataset_path: str = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + } + ) + dataset_name: str = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + dataset_config_name: str = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="validation", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + audio_column_name: str = field( + default="audio", + metadata={ + "help": "The name of the dataset column containing the audio data. Defaults to 'audio'" + }, + ) + text_column_name: str = field( + default="text", + metadata={ + "help": "The name of the dataset column containing the text data. 
Defaults to 'text'" + }, + ) + wav_filesize_column_name: str = field( + default=None, + metadata={ + "help": "The name of the dataset column containing the wav filesize. Defaults is None" + }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached preprocessed datasets or not."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) + chars_to_ignore: Optional[List[str]] = list_field( + default=None, + metadata={"help": "A list of characters to remove from the transcripts."}, + ) + eval_metrics: List[str] = list_field( + default=["wer"], + metadata={ + "help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`" + }, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: float = field( + default=0.0, + metadata={ + "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds" + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. 
" + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + print_samples: bool = field( + default=False, + metadata={ + "help": "Print row with validation inference results to stdout after each epoch" + }, + ) + use_augmentations: bool = field( + default=False, + metadata={ + "help": "Use data augmentation during training" + }, + ) + use_auth_token: str = field( + default="", + metadata={ + "help": "If :obj:`True`, will use the token generated when running" + ":obj:`transformers-cli login` as HTTP bearer authorization for remote files." + }, + ) + unk_token: str = field( + default="[UNK]", + metadata={"help": "The unk token for the tokenizer"}, + ) + pad_token: str = field( + default="[PAD]", + metadata={"help": "The padding token for the tokenizer"}, + ) + word_delimiter_token: str = field( + default="|", + metadata={"help": "The word delimiter token for the tokenizer"}, + ) + phoneme_language: Optional[str] = field( + default=None, + metadata={ + "help": "The target language that should be used be" + " passed to the tokenizer for tokenization. Note that" + " this is only relevant if the model classifies the" + " input audio to a sequence of phoneme sequences." 
+ }, + ) + + +class Augmentator: + + def __init__( + self, + apply_gaussian_noise_with_p=0.1, + apply_gain_with_p=0.1, + apply_pitch_shift_with_p=0.1, + apply_time_stretch_with_p=0.1, + augment_proba=0.1, + sample_rate=16_000 + ): + self.augmentator_fn = None + self.sample_rate = sample_rate + self.augment_proba = augment_proba + all_p = ( + apply_gaussian_noise_with_p + + apply_gain_with_p + + apply_pitch_shift_with_p + + apply_time_stretch_with_p + ) + if AUDIOMENTATIONS_AVAILABLE and all_p > 0: + self.augmentator_fn = Compose([ + TimeStretch(min_rate=0.8, max_rate=1.2, leave_length_unchanged=False, + p=apply_time_stretch_with_p), + PitchShift(min_semitones=-1, max_semitones=1, + p=apply_pitch_shift_with_p), + Gain(min_gain_in_db=-1, max_gain_in_db=1, p=apply_gain_with_p), + AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, + p=apply_gaussian_noise_with_p), + ]) + + def __call__(self, input_values: List[float], *args, **kwargs): + if AUDIOMENTATIONS_AVAILABLE and self.augmentator_fn is not None: + return self.augmentator_fn(samples=np.array(input_values), + sample_rate=self.sample_rate).tolist() + else: + return input_values + + +@dataclass +class DataCollatorCTCWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor (:class:`~transformers.AutoProcessor`) + The processor used for proccessing the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. 
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + max_length_labels (:obj:`int`, `optional`): + Maximum length of the ``labels`` returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + processor: 'AutoProcessor' + padding: Union[bool, str] = "longest" + pad_to_multiple_of: Optional[int] = None + pad_to_multiple_of_labels: Optional[int] = None + augmentator_fn: Optional[Callable] = None + use_augmentations: bool = False + + def __call__( + self, features: List[Dict[str, Union[List[int], torch.Tensor]]] + ) -> Dict[str, torch.Tensor]: + # split inputs and labels since they have to be of different lenghts and need + # different padding methods + input_features = [ + { + "input_values": self.augmentator_fn(feature["input_values"]) + if self.use_augmentations + else feature["input_values"]} + for feature in features + ] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + batch = self.processor.pad( + input_features, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + + with self.processor.as_target_processor(): + labels_batch = self.processor.pad( + label_features, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of_labels, + return_tensors="pt", + ) + + # replace padding with -100 to ignore loss correctly + labels = labels_batch["input_ids"].masked_fill( + labels_batch.attention_mask.ne(1), -100 + ) + + batch["labels"] = labels + + return batch + + +def create_vocabulary_from_data( + datasets: DatasetDict, + 
text_column_name: str, + train_split_name: str, + word_delimiter_token: Optional[str] = None, + unk_token: Optional[str] = None, + pad_token: Optional[str] = None, +): + # Given training and test labels create vocabulary + def extract_all_chars(batch): + all_text = " ".join(batch[text_column_name]) + vocab = list(set(all_text)) + return {"vocab": [vocab], "all_text": [all_text]} + + print("extract chars") + vocabs = datasets.map( + extract_all_chars, + batched=True, + batch_size=-1, + keep_in_memory=True, + remove_columns=datasets[train_split_name].column_names, + ) + + # take union of all unique characters in each dataset + print("make vocab_set") + vocab_set = functools.reduce( + lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), + vocabs.values(), + ) + + vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))} + + # replace white space with delimiter token + if word_delimiter_token is not None: + vocab_dict[word_delimiter_token] = vocab_dict[" "] + del vocab_dict[" "] + + # add unk and pad token + if unk_token is not None: + vocab_dict[unk_token] = len(vocab_dict) + + if pad_token is not None: + vocab_dict[pad_token] = len(vocab_dict) + + return vocab_dict + + +def speech_file_to_array_fn(batch, audio_column_name, dataset_path=""): + if dataset_path: + dataset_path = os.path.join(dataset_path, batch[audio_column_name]) + else: + dataset_path = batch[audio_column_name] if isinstance(batch[audio_column_name], + str) else \ + batch[audio_column_name]["path"] + speech_array, sampling_rate = torchaudio.load(dataset_path) + batch[audio_column_name] = { + "array": speech_array[0].numpy(), + "sampling_rate": sampling_rate, + } + return batch + + +class PrintSamplesPredictionCallback(TrainerCallback): + + def __init__(self, processor, eval_dataset): + super(PrintSamplesPredictionCallback, self).__init__() + self.processor = processor + self.eval_dataset = eval_dataset + self.metric_fn = load_metric("wer") + + def on_log( + self, + 
args: Any, + state: Any, + control: Any, + model: Any, + logs: Optional[Any] = None, + **kwargs + ): + """ + :param args: + :param state: + :param control: + :param model: + :param logs: + :param kwargs: 'tokenizer', 'optimizer', 'lr_scheduler', 'train_dataloader', 'eval_dataloader' + :return: + """ + if state.is_local_process_zero: + columns = ["id", "prediction", "reference", "audio", "wer"] + data = [] + for idx, row in enumerate(self.eval_dataset): + input_dict = self.processor(row["input_values"], + return_tensors="pt", padding=True) + logits = model(input_dict.input_values.to(model.device)).logits + pred_ids = torch.argmax(logits, dim=-1)[0] + prediction = self.processor.decode(pred_ids) + print(f"Prediction: {prediction}") + reference = row['references'].lower() + print(f"\nReference: {reference}") + + if WANDB_AVAILABLE: + + audio, sample_rate = tuple(row["audio"].values()) + audio = wandb.Audio(np.squeeze(audio), + sample_rate=sample_rate) + wer = self.metric_fn.compute( + predictions=[prediction], + references=[reference], + ) + + data.append([idx, prediction, reference, audio, wer]) + if WANDB_AVAILABLE: + table = wandb.Table(data=data, columns=columns) + wandb.run.log({"audio_predictions": table}) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Detecting last checkpoint. 
+ last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel( + logging.INFO if is_main_process(training_args.local_rank) else logging.WARN + ) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed before initializing model. + set_seed(training_args.seed) + + train_split_name = data_args.train_split_name + eval_split_name = data_args.eval_split_name + + # 1. 
First, let's load the dataset + raw_datasets = DatasetDict({ + train_split_name: None, + eval_split_name: None, + }) + + if data_args.dataset_path: + raw_datasets = load_dataset( + "csv", + data_files={ + train_split_name: os.path.join(data_args.dataset_path, "train-all.csv"), + eval_split_name: os.path.join(data_args.dataset_path, "eval-all.csv"), + }, + ) + + if training_args.do_train: + if raw_datasets[train_split_name] is None: + raw_datasets[train_split_name] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.train_split_name, + use_auth_token=data_args.use_auth_token, + ) + + if data_args.audio_column_name not in raw_datasets[train_split_name].column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(raw_datasets['train'].column_names)}." + ) + + if data_args.text_column_name not in raw_datasets[train_split_name].column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(raw_datasets['train'].column_names)}." 
+ ) + + if data_args.max_train_samples is not None: + raw_datasets[train_split_name] = raw_datasets[train_split_name].select( + range(data_args.max_train_samples) + ) + + if data_args.wav_filesize_column_name is not None: + raw_datasets[train_split_name] = raw_datasets[train_split_name].sort( + data_args.wav_filesize_column_name, reverse=True) + + if training_args.do_eval: + if raw_datasets[eval_split_name] is None: + raw_datasets[eval_split_name] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.eval_split_name, + use_auth_token=data_args.use_auth_token, + ) + + if data_args.max_eval_samples is not None: + raw_datasets[eval_split_name] = raw_datasets[eval_split_name].select( + range(data_args.max_eval_samples) + ) + if data_args.wav_filesize_column_name is not None: + raw_datasets[eval_split_name] = raw_datasets[eval_split_name].sort( + data_args.wav_filesize_column_name, reverse=True) + + # save special tokens for tokenizer + word_delimiter_token = data_args.word_delimiter_token + unk_token = data_args.unk_token + pad_token = data_args.pad_token + + # 3. Next, let's load the config as we might need it to create + # the tokenizer + # load config + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_auth_token=data_args.use_auth_token, + ) + + # 4. Next, if no tokenizer file is defined, + # we create the vocabulary of the model by extracting all unique characters from + # the training and evaluation datasets + # We need to make sure that only first rank saves vocabulary + # make sure all processes wait until vocab is created + tokenizer_name_or_path = model_args.tokenizer_name_or_path + tokenizer_kwargs = {} + + # 5. Now we can instantiate the feature extractor, tokenizer and model + # Note for distributed training, the .from_pretrained methods guarantee that only + # one local process can concurrently download model & vocab. 
+ with open(os.path.join(tokenizer_name_or_path, "vocab.json"), "r") as fin: + print("loading tokenizer") + print(fin.read()) + + # load feature_extractor and tokenizer + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name_or_path, + use_auth_token=data_args.use_auth_token, + **tokenizer_kwargs, + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_auth_token=data_args.use_auth_token, + ) + + # adapt config + config.update( + { + "feat_proj_dropout": model_args.feat_proj_dropout, + "attention_dropout": model_args.attention_dropout, + "hidden_dropout": model_args.hidden_dropout, + "final_dropout": model_args.final_dropout, + "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, + "gradient_checkpointing": training_args.gradient_checkpointing, + "layerdrop": model_args.layerdrop, + "ctc_loss_reduction": model_args.ctc_loss_reduction, + "pad_token_id": tokenizer.pad_token_id, + "vocab_size": len(tokenizer), + "activation_dropout": model_args.activation_dropout, + } + ) + + # create model + model = AutoModelForCTC.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + config=config, + use_auth_token=data_args.use_auth_token, + ) + + # freeze encoder + if model_args.freeze_feature_encoder: + model.freeze_feature_encoder() + + # 6. 
Now we preprocess the datasets including loading the audio, resampling and normalization + # Thankfully, `datasets` takes care of automatically loading and resampling the audio, + # so that we just need to set the correct target sampling rate and normalize the input + # via the `feature_extractor` + + # make sure that dataset decodes audio with correct sampling rate + + # derive max & min input length for sample rate & max duration + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + + # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification + phoneme_language = data_args.phoneme_language + + raw_datasets[train_split_name] = raw_datasets[train_split_name].map( + speech_file_to_array_fn, + num_proc=num_workers, + fn_kwargs={"dataset_path": data_args.dataset_path, + "audio_column_name": audio_column_name}, + ) + raw_datasets[eval_split_name] = raw_datasets[eval_split_name].map( + speech_file_to_array_fn, + num_proc=num_workers, + fn_kwargs={"dataset_path": data_args.dataset_path, + "audio_column_name": audio_column_name}, + ) + + # Preprocessing the datasets. + # We need to read the audio files as arrays and tokenize the targets. 
+ def prepare_dataset(batch): + # load audio + sample = batch[audio_column_name] + + inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + batch["input_values"] = inputs.input_values[0] + batch["input_length"] = len(batch["input_values"]) + + # encode targets + additional_kwargs = {} + if phoneme_language is not None: + additional_kwargs["phonemizer_lang"] = phoneme_language + + batch["labels"] = tokenizer(batch[data_args.text_column_name], + **additional_kwargs).input_ids + return batch + + print(f"Vectorizing") + + with training_args.main_process_first(desc="dataset map preprocessing"): + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=num_workers, + desc="preprocess datasets", + ) + + # 7. Next, we can prepare the training. + # Let's use word error rate (WER) as our evaluation metric, + # instantiate a data collator and the trainer + + # Define evaluation metrics during training, *i.e.* word error rate, character error rate + eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics} + + # for large datasets it is advised to run the preprocessing on a + # single machine first with ``args.preprocessing_only`` since there will mostly likely + # be a timeout when running the script in distributed mode. + # In a second step ``args.preprocessing_only`` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + logger.info( + f"Data preprocessing finished. 
Files cached at {vectorized_datasets.cache_files}" + ) + return + + def compute_metrics(pred): + pred_logits = pred.predictions + pred_ids = np.argmax(pred_logits, axis=-1) + + pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id + + pred_str = tokenizer.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False) + + metrics = { + k: v.compute(predictions=pred_str, references=label_str) + for k, v in eval_metrics.items() + } + + return metrics + + # Now save everything to be able to create a single processor later + if is_main_process(training_args.local_rank): + # save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + try: + processor = AutoProcessor.from_pretrained(training_args.output_dir) + except (OSError, KeyError): + warnings.warn( + "Loading a processor from a feature extractor config that does not" + " include a `processor_class` attribute is deprecated and will be removed in v5. 
Please add the following " + " attribute to your `preprocessor_config.json` file to suppress this warning: " + " `'processor_class': 'Wav2Vec2Processor'`", + FutureWarning, + ) + processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir) + + # Instantiate custom data collator + data_collator = DataCollatorCTCWithPadding( + processor=processor, + augmentator_fn=Augmentator(), + use_augmentations=data_args.use_augmentations + ) + + decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "weight_decay": training_args.weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() if n not in decay_parameters + ], + "weight_decay": 0.0, + }, + ] + trainer_kwargs = {} + if BNB_AVAILABLE: + optimizer = bnb.optim.Adam8bit( + params=optimizer_grouped_parameters, + betas=(training_args.adam_beta1, training_args.adam_beta2), + eps=training_args.adam_epsilon, + ) + trainer_kwargs["optimizers"] = (optimizer, None) + + samples_to_log = [ + { + **vectorized_datasets[eval_split_name][i], + "references": raw_datasets[eval_split_name][i][data_args.text_column_name], + "audio": raw_datasets[eval_split_name][i][data_args.audio_column_name], + } for i in range(5) + ] + + trainer = Trainer( + model=model, + data_collator=data_collator, + args=training_args, + compute_metrics=compute_metrics, + train_dataset=vectorized_datasets[ + train_split_name] if training_args.do_train else None, + eval_dataset=vectorized_datasets[ + eval_split_name] if training_args.do_eval else None, + tokenizer=feature_extractor, + **trainer_kwargs, + callbacks=[PrintSamplesPredictionCallback( + processor=processor, + eval_dataset=samples_to_log)] if data_args.print_samples and training_args.do_eval else None, + ) + + # 8. 
Finally, we can start training + + # Training + if training_args.do_train: + + # use last checkpoint if exist + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + checkpoint = model_args.model_name_or_path + else: + checkpoint = None + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(vectorized_datasets[train_split_name]) + ) + metrics["train_samples"] = min( + max_train_samples, len(vectorized_datasets[train_split_name]) + ) + + trainer.log_metrics(train_split_name, metrics) + trainer.save_metrics(train_split_name, metrics) + trainer.save_state() + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + max_eval_samples = ( + data_args.max_eval_samples + if data_args.max_eval_samples is not None + else len(vectorized_datasets[eval_split_name]) + ) + metrics["eval_samples"] = min(max_eval_samples, + len(vectorized_datasets[eval_split_name])) + + trainer.log_metrics(eval_split_name, metrics) + trainer.save_metrics(eval_split_name, metrics) + + # Write model card and (optionally) push to hub + config_name = ( + data_args.dataset_config_name + if data_args.dataset_config_name is not None + else "na" + ) + kwargs = { + "language": "he", + "finetuned_from": model_args.model_name_or_path, + "tasks": "speech-recognition", + "tags": ["automatic-speech-recognition", "robust-speech-event", "he"], + "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}", + } + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + return results + + +if __name__ == "__main__": + main() diff --git a/run_train.sh b/run_train.sh new file mode 100644 
index 0000000..9375ab4 --- /dev/null +++ b/run_train.sh @@ -0,0 +1,39 @@ +export CUDA_VISIBLE_DEVICES="0,1" + +python -m torch.distributed.launch --nproc_per_node=2 run_train.py \ + --dataset_name="imvladikon/hebrew_speech_???" \ + --use_auth_token="???" \ + --audio_column_name="audio" \ + --text_column_name="sentence" \ + --model_name_or_path="imvladikon/wav2vec2-xls-r-300m-hebrew" \ + --tokenizer_name_or_path="./wav2vec2-xls-r-300m-hebrew" \ + --output_dir="./wav2vec2-xls-r-300m-hebrew" \ + --overwrite_output_dir \ + --evaluation_strategy="steps" \ + --length_column_name="input_length" \ + --gradient_checkpointing \ + --fp16 \ + --group_by_length \ + --num_train_epochs="100" \ + --per_device_train_batch_size="8" \ + --per_device_eval_batch_size="8" \ + --gradient_accumulation_steps="4" \ + --learning_rate="3e-4" \ + --warmup_steps="1000" \ + --save_steps="1000" \ + --eval_steps="1000" \ + --preprocessing_num_workers="$(nproc)" \ + --logging_steps="2000" \ + --layerdrop="0.0" \ + --activation_dropout="0.1" \ + --save_total_limit="3" \ + --freeze_feature_encoder \ + --feat_proj_dropout="0.0" \ + --mask_time_prob="0.75" \ + --mask_time_length="10" \ + --mask_feature_prob="0.25" \ + --mask_feature_length="64" \ + --do_train --do_eval \ + --print_samples \ + --use_augmentations \ + --push_to_hub \ No newline at end of file diff --git a/special_tokens_map.json b/special_tokens_map.json new file mode 100644 index 0000000..258236c --- /dev/null +++ b/special_tokens_map.json @@ -0,0 +1 @@ +{"bos_token": "", "eos_token": "", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, 
{"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]} \ No newline at end of file diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000..2a4a5c3 --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1 @@ +{"unk_token": "[UNK]", "bos_token": "", "eos_token": "", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": null, "tokenizer_file": null, "name_or_path": "./wav2vec2-xls-r-300m-a-hebrew", "tokenizer_class": "Wav2Vec2CTCTokenizer"} \ No newline at end of file diff --git a/train_results.json b/train_results.json new file mode 100644 index 0000000..3cef30e --- /dev/null +++ b/train_results.json @@ -0,0 +1,8 @@ +{ + "epoch": 60.0, + "train_loss": 1.0617500136590194, + "train_runtime": 191647.3893, + "train_samples": 90777, + "train_samples_per_second": 28.42, + "train_steps_per_second": 0.444 +} \ No newline at end of file diff --git a/trainer_state.json b/trainer_state.json new file mode 100644 index 0000000..3d39585 --- /dev/null +++ b/trainer_state.json @@ -0,0 +1,1042 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 59.999647514980616, + "global_step": 85080, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ 
+ { + "epoch": 0.7, + "eval_loss": 0.5371278524398804, + "eval_runtime": 423.1427, + "eval_samples_per_second": 47.847, + "eval_steps_per_second": 2.992, + "eval_wer": 0.3810800894541897, + "step": 1000 + }, + { + "epoch": 1.41, + "learning_rate": 0.0009881184586108469, + "loss": 1.3606, + "step": 2000 + }, + { + "epoch": 1.41, + "eval_loss": 0.5247300267219543, + "eval_runtime": 419.7941, + "eval_samples_per_second": 48.228, + "eval_steps_per_second": 3.016, + "eval_wer": 0.3902264309176043, + "step": 2000 + }, + { + "epoch": 2.12, + "eval_loss": 0.512639582157135, + "eval_runtime": 421.2981, + "eval_samples_per_second": 48.056, + "eval_steps_per_second": 3.005, + "eval_wer": 0.38586728632329304, + "step": 3000 + }, + { + "epoch": 2.82, + "learning_rate": 0.0009643553758325404, + "loss": 1.3671, + "step": 4000 + }, + { + "epoch": 2.82, + "eval_loss": 0.5062423348426819, + "eval_runtime": 423.4242, + "eval_samples_per_second": 47.815, + "eval_steps_per_second": 2.99, + "eval_wer": 0.3827660912712279, + "step": 4000 + }, + { + "epoch": 3.53, + "eval_loss": 0.49790534377098083, + "eval_runtime": 423.631, + "eval_samples_per_second": 47.792, + "eval_steps_per_second": 2.988, + "eval_wer": 0.3671902299252219, + "step": 5000 + }, + { + "epoch": 4.23, + "learning_rate": 0.0009405803996194102, + "loss": 1.3421, + "step": 6000 + }, + { + "epoch": 4.23, + "eval_loss": 0.4905886650085449, + "eval_runtime": 422.3048, + "eval_samples_per_second": 47.942, + "eval_steps_per_second": 2.998, + "eval_wer": 0.38159549933608217, + "step": 6000 + }, + { + "epoch": 4.94, + "eval_loss": 0.4783899188041687, + "eval_runtime": 421.2427, + "eval_samples_per_second": 48.063, + "eval_steps_per_second": 3.005, + "eval_wer": 0.36512859039765183, + "step": 7000 + }, + { + "epoch": 5.64, + "learning_rate": 0.0009168054234062797, + "loss": 1.328, + "step": 8000 + }, + { + "epoch": 5.64, + "eval_loss": 0.48102879524230957, + "eval_runtime": 420.6976, + "eval_samples_per_second": 48.125, + 
"eval_steps_per_second": 3.009, + "eval_wer": 0.3669106855825005, + "step": 8000 + }, + { + "epoch": 6.35, + "eval_loss": 0.47466185688972473, + "eval_runtime": 421.0932, + "eval_samples_per_second": 48.08, + "eval_steps_per_second": 3.006, + "eval_wer": 0.35974736180026556, + "step": 9000 + }, + { + "epoch": 7.05, + "learning_rate": 0.0008930185537583254, + "loss": 1.3109, + "step": 10000 + }, + { + "epoch": 7.05, + "eval_loss": 0.4812825620174408, + "eval_runtime": 417.8509, + "eval_samples_per_second": 48.453, + "eval_steps_per_second": 3.03, + "eval_wer": 0.3808267523935984, + "step": 10000 + }, + { + "epoch": 7.76, + "eval_loss": 0.46314355731010437, + "eval_runtime": 421.3095, + "eval_samples_per_second": 48.055, + "eval_steps_per_second": 3.005, + "eval_wer": 0.3560696065413376, + "step": 11000 + }, + { + "epoch": 8.46, + "learning_rate": 0.0008692435775451952, + "loss": 1.2873, + "step": 12000 + }, + { + "epoch": 8.46, + "eval_loss": 0.4602561295032501, + "eval_runtime": 419.2403, + "eval_samples_per_second": 48.292, + "eval_steps_per_second": 3.02, + "eval_wer": 0.3430533230833741, + "step": 12000 + }, + { + "epoch": 9.17, + "eval_loss": 0.4578526020050049, + "eval_runtime": 418.3262, + "eval_samples_per_second": 48.398, + "eval_steps_per_second": 3.026, + "eval_wer": 0.3532916346355441, + "step": 13000 + }, + { + "epoch": 9.87, + "learning_rate": 0.0008454567078972407, + "loss": 1.2661, + "step": 14000 + }, + { + "epoch": 9.87, + "eval_loss": 0.44712555408477783, + "eval_runtime": 420.1714, + "eval_samples_per_second": 48.185, + "eval_steps_per_second": 3.013, + "eval_wer": 0.336475295268712, + "step": 14000 + }, + { + "epoch": 10.58, + "eval_loss": 0.45836716890335083, + "eval_runtime": 415.4093, + "eval_samples_per_second": 48.737, + "eval_steps_per_second": 3.048, + "eval_wer": 0.3436910336152072, + "step": 15000 + }, + { + "epoch": 11.28, + "learning_rate": 0.0008216817316841104, + "loss": 1.249, + "step": 16000 + }, + { + "epoch": 11.28, + 
"eval_loss": 0.4460853040218353, + "eval_runtime": 419.1988, + "eval_samples_per_second": 48.297, + "eval_steps_per_second": 3.02, + "eval_wer": 0.3454294499965057, + "step": 16000 + }, + { + "epoch": 11.99, + "eval_loss": 0.44824984669685364, + "eval_runtime": 419.1569, + "eval_samples_per_second": 48.302, + "eval_steps_per_second": 3.02, + "eval_wer": 0.3367024250471731, + "step": 17000 + }, + { + "epoch": 12.69, + "learning_rate": 0.00079790675547098, + "loss": 1.2322, + "step": 18000 + }, + { + "epoch": 12.69, + "eval_loss": 0.44639432430267334, + "eval_runtime": 417.2494, + "eval_samples_per_second": 48.523, + "eval_steps_per_second": 3.034, + "eval_wer": 0.33347019358445734, + "step": 18000 + }, + { + "epoch": 13.4, + "eval_loss": 0.4426889717578888, + "eval_runtime": 417.2512, + "eval_samples_per_second": 48.522, + "eval_steps_per_second": 3.034, + "eval_wer": 0.3454469215179258, + "step": 19000 + }, + { + "epoch": 14.1, + "learning_rate": 0.0007741317792578497, + "loss": 1.22, + "step": 20000 + }, + { + "epoch": 14.1, + "eval_loss": 0.44404253363609314, + "eval_runtime": 419.2873, + "eval_samples_per_second": 48.287, + "eval_steps_per_second": 3.019, + "eval_wer": 0.33950660423509676, + "step": 20000 + }, + { + "epoch": 14.81, + "eval_loss": 0.44594520330429077, + "eval_runtime": 418.1805, + "eval_samples_per_second": 48.415, + "eval_steps_per_second": 3.027, + "eval_wer": 0.3378468097001887, + "step": 21000 + }, + { + "epoch": 15.51, + "learning_rate": 0.0007503449096098953, + "loss": 1.2044, + "step": 22000 + }, + { + "epoch": 15.51, + "eval_loss": 0.4406070411205292, + "eval_runtime": 418.2792, + "eval_samples_per_second": 48.403, + "eval_steps_per_second": 3.027, + "eval_wer": 0.3199035572017611, + "step": 22000 + }, + { + "epoch": 16.22, + "eval_loss": 0.4397943317890167, + "eval_runtime": 419.2827, + "eval_samples_per_second": 48.287, + "eval_steps_per_second": 3.019, + "eval_wer": 0.31545705500034943, + "step": 23000 + }, + { + "epoch": 16.92, + 
"learning_rate": 0.000726581826831589, + "loss": 1.1913, + "step": 24000 + }, + { + "epoch": 16.92, + "eval_loss": 0.42369282245635986, + "eval_runtime": 417.1191, + "eval_samples_per_second": 48.538, + "eval_steps_per_second": 3.035, + "eval_wer": 0.314959116639877, + "step": 24000 + }, + { + "epoch": 17.63, + "eval_loss": 0.4286690652370453, + "eval_runtime": 422.4918, + "eval_samples_per_second": 47.92, + "eval_steps_per_second": 2.997, + "eval_wer": 0.3278880424907401, + "step": 25000 + }, + { + "epoch": 18.34, + "learning_rate": 0.0007028068506184586, + "loss": 1.1705, + "step": 26000 + }, + { + "epoch": 18.34, + "eval_loss": 0.4252525866031647, + "eval_runtime": 419.581, + "eval_samples_per_second": 48.253, + "eval_steps_per_second": 3.017, + "eval_wer": 0.31029422042071425, + "step": 26000 + }, + { + "epoch": 19.04, + "eval_loss": 0.42342108488082886, + "eval_runtime": 420.8706, + "eval_samples_per_second": 48.105, + "eval_steps_per_second": 3.008, + "eval_wer": 0.30976133901740166, + "step": 27000 + }, + { + "epoch": 19.75, + "learning_rate": 0.0006790199809705044, + "loss": 1.1564, + "step": 28000 + }, + { + "epoch": 19.75, + "eval_loss": 0.4174347519874573, + "eval_runtime": 419.264, + "eval_samples_per_second": 48.289, + "eval_steps_per_second": 3.02, + "eval_wer": 0.3076472849255713, + "step": 28000 + }, + { + "epoch": 20.45, + "eval_loss": 0.42600032687187195, + "eval_runtime": 415.7754, + "eval_samples_per_second": 48.695, + "eval_steps_per_second": 3.045, + "eval_wer": 0.31604235096792227, + "step": 29000 + }, + { + "epoch": 21.16, + "learning_rate": 0.0006552450047573739, + "loss": 1.1461, + "step": 30000 + }, + { + "epoch": 21.16, + "eval_loss": 0.42350149154663086, + "eval_runtime": 419.5333, + "eval_samples_per_second": 48.258, + "eval_steps_per_second": 3.018, + "eval_wer": 0.30363757075966175, + "step": 30000 + }, + { + "epoch": 21.86, + "eval_loss": 0.43086713552474976, + "eval_runtime": 421.0996, + "eval_samples_per_second": 48.079, + 
"eval_steps_per_second": 3.006, + "eval_wer": 0.30550702355161086, + "step": 31000 + }, + { + "epoch": 22.57, + "learning_rate": 0.0006314581351094196, + "loss": 1.1285, + "step": 32000 + }, + { + "epoch": 22.57, + "eval_loss": 0.4263738989830017, + "eval_runtime": 418.0963, + "eval_samples_per_second": 48.424, + "eval_steps_per_second": 3.028, + "eval_wer": 0.3006324690754071, + "step": 32000 + }, + { + "epoch": 23.27, + "eval_loss": 0.420136034488678, + "eval_runtime": 418.959, + "eval_samples_per_second": 48.325, + "eval_steps_per_second": 3.022, + "eval_wer": 0.28796561604584525, + "step": 33000 + }, + { + "epoch": 23.98, + "learning_rate": 0.0006076950523311132, + "loss": 1.1135, + "step": 34000 + }, + { + "epoch": 23.98, + "eval_loss": 0.41308358311653137, + "eval_runtime": 419.6417, + "eval_samples_per_second": 48.246, + "eval_steps_per_second": 3.017, + "eval_wer": 0.29749633098050177, + "step": 34000 + }, + { + "epoch": 24.68, + "eval_loss": 0.42022430896759033, + "eval_runtime": 415.5643, + "eval_samples_per_second": 48.719, + "eval_steps_per_second": 3.046, + "eval_wer": 0.2848556852330701, + "step": 35000 + }, + { + "epoch": 25.39, + "learning_rate": 0.0005839200761179829, + "loss": 1.0968, + "step": 36000 + }, + { + "epoch": 25.39, + "eval_loss": 0.41045960783958435, + "eval_runtime": 416.7195, + "eval_samples_per_second": 48.584, + "eval_steps_per_second": 3.038, + "eval_wer": 0.2887867775525893, + "step": 36000 + }, + { + "epoch": 26.09, + "eval_loss": 0.4209502339363098, + "eval_runtime": 419.7061, + "eval_samples_per_second": 48.239, + "eval_steps_per_second": 3.016, + "eval_wer": 0.28344922775875325, + "step": 37000 + }, + { + "epoch": 26.8, + "learning_rate": 0.0005601332064700286, + "loss": 1.087, + "step": 38000 + }, + { + "epoch": 26.8, + "eval_loss": 0.4122714400291443, + "eval_runtime": 417.6093, + "eval_samples_per_second": 48.481, + "eval_steps_per_second": 3.032, + "eval_wer": 0.2843228038297575, + "step": 38000 + }, + { + "epoch": 27.5, 
+ "eval_loss": 0.42156872153282166, + "eval_runtime": 415.6358, + "eval_samples_per_second": 48.711, + "eval_steps_per_second": 3.046, + "eval_wer": 0.2802694108602977, + "step": 39000 + }, + { + "epoch": 28.21, + "learning_rate": 0.0005363582302568982, + "loss": 1.0707, + "step": 40000 + }, + { + "epoch": 28.21, + "eval_loss": 0.4161020517349243, + "eval_runtime": 421.5813, + "eval_samples_per_second": 48.024, + "eval_steps_per_second": 3.003, + "eval_wer": 0.2786707666503599, + "step": 40000 + }, + { + "epoch": 28.91, + "eval_loss": 0.4186325967311859, + "eval_runtime": 419.2011, + "eval_samples_per_second": 48.297, + "eval_steps_per_second": 3.02, + "eval_wer": 0.27398839890977705, + "step": 41000 + }, + { + "epoch": 29.62, + "learning_rate": 0.0005125713606089438, + "loss": 1.0575, + "step": 42000 + }, + { + "epoch": 29.62, + "eval_loss": 0.41177135705947876, + "eval_runtime": 420.5227, + "eval_samples_per_second": 48.145, + "eval_steps_per_second": 3.011, + "eval_wer": 0.2844625760011182, + "step": 42000 + }, + { + "epoch": 30.32, + "eval_loss": 0.4242798388004303, + "eval_runtime": 415.2438, + "eval_samples_per_second": 48.757, + "eval_steps_per_second": 3.049, + "eval_wer": 0.27729925221888324, + "step": 43000 + }, + { + "epoch": 31.03, + "learning_rate": 0.0004888082778306375, + "loss": 1.0474, + "step": 44000 + }, + { + "epoch": 31.03, + "eval_loss": 0.4220682382583618, + "eval_runtime": 420.3236, + "eval_samples_per_second": 48.168, + "eval_steps_per_second": 3.012, + "eval_wer": 0.27072122440422114, + "step": 44000 + }, + { + "epoch": 31.73, + "eval_loss": 0.4138353765010834, + "eval_runtime": 416.304, + "eval_samples_per_second": 48.633, + "eval_steps_per_second": 3.041, + "eval_wer": 0.2699524774617374, + "step": 45000 + }, + { + "epoch": 32.44, + "learning_rate": 0.00046502140818268316, + "loss": 1.0333, + "step": 46000 + }, + { + "epoch": 32.44, + "eval_loss": 0.41024765372276306, + "eval_runtime": 419.6952, + "eval_samples_per_second": 48.24, + 
"eval_steps_per_second": 3.016, + "eval_wer": 0.26381997344328745, + "step": 46000 + }, + { + "epoch": 33.15, + "eval_loss": 0.4162220358848572, + "eval_runtime": 421.7395, + "eval_samples_per_second": 48.006, + "eval_steps_per_second": 3.002, + "eval_wer": 0.26496435809630303, + "step": 47000 + }, + { + "epoch": 33.85, + "learning_rate": 0.00044123453853472884, + "loss": 1.0191, + "step": 48000 + }, + { + "epoch": 33.85, + "eval_loss": 0.4154505133628845, + "eval_runtime": 415.0181, + "eval_samples_per_second": 48.783, + "eval_steps_per_second": 3.05, + "eval_wer": 0.2636452582290866, + "step": 48000 + }, + { + "epoch": 34.56, + "eval_loss": 0.41287967562675476, + "eval_runtime": 419.9379, + "eval_samples_per_second": 48.212, + "eval_steps_per_second": 3.015, + "eval_wer": 0.2655933328674261, + "step": 49000 + }, + { + "epoch": 35.26, + "learning_rate": 0.0004174595623215985, + "loss": 1.0087, + "step": 50000 + }, + { + "epoch": 35.26, + "eval_loss": 0.4157230257987976, + "eval_runtime": 417.0542, + "eval_samples_per_second": 48.545, + "eval_steps_per_second": 3.036, + "eval_wer": 0.2631909986721644, + "step": 50000 + }, + { + "epoch": 35.97, + "eval_loss": 0.40904876589775085, + "eval_runtime": 417.7439, + "eval_samples_per_second": 48.465, + "eval_steps_per_second": 3.031, + "eval_wer": 0.26537493884967506, + "step": 51000 + }, + { + "epoch": 36.67, + "learning_rate": 0.0003936845861084681, + "loss": 0.9901, + "step": 52000 + }, + { + "epoch": 36.67, + "eval_loss": 0.41830816864967346, + "eval_runtime": 417.9487, + "eval_samples_per_second": 48.441, + "eval_steps_per_second": 3.029, + "eval_wer": 0.25867461038507233, + "step": 52000 + }, + { + "epoch": 37.38, + "eval_loss": 0.4250655770301819, + "eval_runtime": 417.46, + "eval_samples_per_second": 48.498, + "eval_steps_per_second": 3.033, + "eval_wer": 0.2648420574463624, + "step": 53000 + }, + { + "epoch": 38.08, + "learning_rate": 0.00036990960989533776, + "loss": 0.9795, + "step": 54000 + }, + { + "epoch": 
38.08, + "eval_loss": 0.4228881299495697, + "eval_runtime": 417.6821, + "eval_samples_per_second": 48.472, + "eval_steps_per_second": 3.031, + "eval_wer": 0.25547732196519674, + "step": 54000 + }, + { + "epoch": 38.79, + "eval_loss": 0.4176000952720642, + "eval_runtime": 418.3393, + "eval_samples_per_second": 48.396, + "eval_steps_per_second": 3.026, + "eval_wer": 0.2545513313299322, + "step": 55000 + }, + { + "epoch": 39.49, + "learning_rate": 0.00034613463368220746, + "loss": 0.9644, + "step": 56000 + }, + { + "epoch": 39.49, + "eval_loss": 0.4222715497016907, + "eval_runtime": 423.2092, + "eval_samples_per_second": 47.839, + "eval_steps_per_second": 2.991, + "eval_wer": 0.25131036410650637, + "step": 56000 + }, + { + "epoch": 40.2, + "eval_loss": 0.4243711531162262, + "eval_runtime": 419.2699, + "eval_samples_per_second": 48.289, + "eval_steps_per_second": 3.02, + "eval_wer": 0.2530138374449647, + "step": 57000 + }, + { + "epoch": 40.9, + "learning_rate": 0.0003223477640342531, + "loss": 0.9534, + "step": 58000 + }, + { + "epoch": 40.9, + "eval_loss": 0.4174785912036896, + "eval_runtime": 417.6661, + "eval_samples_per_second": 48.474, + "eval_steps_per_second": 3.031, + "eval_wer": 0.2538349989517087, + "step": 58000 + }, + { + "epoch": 41.61, + "eval_loss": 0.4212724566459656, + "eval_runtime": 425.7723, + "eval_samples_per_second": 47.551, + "eval_steps_per_second": 2.973, + "eval_wer": 0.2505416171640226, + "step": 59000 + }, + { + "epoch": 42.31, + "learning_rate": 0.00029856089438629876, + "loss": 0.9397, + "step": 60000 + }, + { + "epoch": 42.31, + "eval_loss": 0.4275393486022949, + "eval_runtime": 419.6181, + "eval_samples_per_second": 48.249, + "eval_steps_per_second": 3.017, + "eval_wer": 0.2565343490111119, + "step": 60000 + }, + { + "epoch": 43.02, + "eval_loss": 0.4315040111541748, + "eval_runtime": 420.9924, + "eval_samples_per_second": 48.091, + "eval_steps_per_second": 3.007, + "eval_wer": 0.2528129149486337, + "step": 61000 + }, + { + "epoch": 
43.72, + "learning_rate": 0.00027479781160799237, + "loss": 0.9269, + "step": 62000 + }, + { + "epoch": 43.72, + "eval_loss": 0.4316493570804596, + "eval_runtime": 417.0076, + "eval_samples_per_second": 48.551, + "eval_steps_per_second": 3.036, + "eval_wer": 0.2501048291285205, + "step": 62000 + }, + { + "epoch": 44.43, + "eval_loss": 0.42470675706863403, + "eval_runtime": 418.3345, + "eval_samples_per_second": 48.397, + "eval_steps_per_second": 3.026, + "eval_wer": 0.24705604864071562, + "step": 63000 + }, + { + "epoch": 45.13, + "learning_rate": 0.00025101094196003804, + "loss": 0.9175, + "step": 64000 + }, + { + "epoch": 45.13, + "eval_loss": 0.43763282895088196, + "eval_runtime": 418.9083, + "eval_samples_per_second": 48.33, + "eval_steps_per_second": 3.022, + "eval_wer": 0.24685512614438465, + "step": 64000 + }, + { + "epoch": 45.84, + "eval_loss": 0.4334784150123596, + "eval_runtime": 418.8375, + "eval_samples_per_second": 48.339, + "eval_steps_per_second": 3.023, + "eval_wer": 0.24501188063456567, + "step": 65000 + }, + { + "epoch": 46.54, + "learning_rate": 0.00022724785918173168, + "loss": 0.9026, + "step": 66000 + }, + { + "epoch": 46.54, + "eval_loss": 0.4336349070072174, + "eval_runtime": 419.7669, + "eval_samples_per_second": 48.232, + "eval_steps_per_second": 3.016, + "eval_wer": 0.24517786008805648, + "step": 66000 + }, + { + "epoch": 47.25, + "eval_loss": 0.4399877190589905, + "eval_runtime": 419.3411, + "eval_samples_per_second": 48.281, + "eval_steps_per_second": 3.019, + "eval_wer": 0.24265322524285415, + "step": 67000 + }, + { + "epoch": 47.95, + "learning_rate": 0.00020346098953377738, + "loss": 0.8929, + "step": 68000 + }, + { + "epoch": 47.95, + "eval_loss": 0.43824830651283264, + "eval_runtime": 419.5549, + "eval_samples_per_second": 48.256, + "eval_steps_per_second": 3.017, + "eval_wer": 0.24285414773918512, + "step": 68000 + }, + { + "epoch": 48.66, + "eval_loss": 0.43613117933273315, + "eval_runtime": 418.6436, + 
"eval_samples_per_second": 48.361, + "eval_steps_per_second": 3.024, + "eval_wer": 0.24154378363267873, + "step": 69000 + }, + { + "epoch": 49.37, + "learning_rate": 0.00017967411988582303, + "loss": 0.8786, + "step": 70000 + }, + { + "epoch": 49.37, + "eval_loss": 0.44130945205688477, + "eval_runtime": 418.5492, + "eval_samples_per_second": 48.372, + "eval_steps_per_second": 3.025, + "eval_wer": 0.23977915996925012, + "step": 70000 + }, + { + "epoch": 50.07, + "eval_loss": 0.43924885988235474, + "eval_runtime": 419.6699, + "eval_samples_per_second": 48.243, + "eval_steps_per_second": 3.017, + "eval_wer": 0.2415001048291285, + "step": 71000 + }, + { + "epoch": 50.78, + "learning_rate": 0.00015591103710751666, + "loss": 0.8714, + "step": 72000 + }, + { + "epoch": 50.78, + "eval_loss": 0.4345008134841919, + "eval_runtime": 418.4339, + "eval_samples_per_second": 48.385, + "eval_steps_per_second": 3.026, + "eval_wer": 0.24062652875812426, + "step": 72000 + }, + { + "epoch": 51.48, + "eval_loss": 0.44752031564712524, + "eval_runtime": 417.7023, + "eval_samples_per_second": 48.47, + "eval_steps_per_second": 3.031, + "eval_wer": 0.24017226920120205, + "step": 73000 + }, + { + "epoch": 52.19, + "learning_rate": 0.00013212416745956231, + "loss": 0.8589, + "step": 74000 + }, + { + "epoch": 52.19, + "eval_loss": 0.4473401606082916, + "eval_runtime": 419.491, + "eval_samples_per_second": 48.263, + "eval_steps_per_second": 3.018, + "eval_wer": 0.23740303305611854, + "step": 74000 + }, + { + "epoch": 52.89, + "eval_loss": 0.4457215368747711, + "eval_runtime": 416.4935, + "eval_samples_per_second": 48.611, + "eval_steps_per_second": 3.04, + "eval_wer": 0.23568208819624012, + "step": 75000 + }, + { + "epoch": 53.6, + "learning_rate": 0.00010834919124643197, + "loss": 0.8493, + "step": 76000 + }, + { + "epoch": 53.6, + "eval_loss": 0.44615164399147034, + "eval_runtime": 416.8896, + "eval_samples_per_second": 48.564, + "eval_steps_per_second": 3.037, + "eval_wer": 
0.2365556642672444, + "step": 76000 + }, + { + "epoch": 54.3, + "eval_loss": 0.44939640164375305, + "eval_runtime": 416.8835, + "eval_samples_per_second": 48.565, + "eval_steps_per_second": 3.037, + "eval_wer": 0.2355947305891397, + "step": 77000 + }, + { + "epoch": 55.01, + "learning_rate": 8.456232159847763e-05, + "loss": 0.8395, + "step": 78000 + }, + { + "epoch": 55.01, + "eval_loss": 0.44722679257392883, + "eval_runtime": 419.1529, + "eval_samples_per_second": 48.302, + "eval_steps_per_second": 3.02, + "eval_wer": 0.23519288559647775, + "step": 78000 + }, + { + "epoch": 55.71, + "eval_loss": 0.44897979497909546, + "eval_runtime": 419.2062, + "eval_samples_per_second": 48.296, + "eval_steps_per_second": 3.02, + "eval_wer": 0.23388252148997135, + "step": 79000 + }, + { + "epoch": 56.42, + "learning_rate": 6.0799238820171265e-05, + "loss": 0.8295, + "step": 80000 + }, + { + "epoch": 56.42, + "eval_loss": 0.4489339590072632, + "eval_runtime": 420.6011, + "eval_samples_per_second": 48.136, + "eval_steps_per_second": 3.01, + "eval_wer": 0.23176846739814103, + "step": 80000 + }, + { + "epoch": 57.12, + "eval_loss": 0.4468826949596405, + "eval_runtime": 417.4618, + "eval_samples_per_second": 48.498, + "eval_steps_per_second": 3.033, + "eval_wer": 0.23203927598015237, + "step": 81000 + }, + { + "epoch": 57.83, + "learning_rate": 3.7012369172216936e-05, + "loss": 0.8225, + "step": 82000 + }, + { + "epoch": 57.83, + "eval_loss": 0.4478228688240051, + "eval_runtime": 417.4954, + "eval_samples_per_second": 48.494, + "eval_steps_per_second": 3.032, + "eval_wer": 0.23214410510867287, + "step": 82000 + }, + { + "epoch": 58.53, + "eval_loss": 0.4525238573551178, + "eval_runtime": 415.49, + "eval_samples_per_second": 48.728, + "eval_steps_per_second": 3.047, + "eval_wer": 0.2326071004263051, + "step": 83000 + }, + { + "epoch": 59.24, + "learning_rate": 1.3225499524262608e-05, + "loss": 0.816, + "step": 84000 + }, + { + "epoch": 59.24, + "eval_loss": 0.4532177150249481, + 
"eval_runtime": 419.041, + "eval_samples_per_second": 48.315, + "eval_steps_per_second": 3.021, + "eval_wer": 0.2315588091411, + "step": 84000 + }, + { + "epoch": 59.94, + "eval_loss": 0.45018914341926575, + "eval_runtime": 418.7584, + "eval_samples_per_second": 48.348, + "eval_steps_per_second": 3.023, + "eval_wer": 0.23179467468027115, + "step": 85000 + }, + { + "epoch": 60.0, + "step": 85080, + "total_flos": 4.6401496923493315e+20, + "train_loss": 1.0617500136590194, + "train_runtime": 191647.3893, + "train_samples_per_second": 28.42, + "train_steps_per_second": 0.444 + } + ], + "max_steps": 85080, + "num_train_epochs": 60, + "total_flos": 4.6401496923493315e+20, + "trial_name": null, + "trial_params": null +} diff --git a/training_args.bin b/training_args.bin new file mode 100644 index 0000000..5c8a844 --- /dev/null +++ b/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfeb10e403369a185fe8d387f3aab04ec2ddef0dea3b5d4d22baea6101fdec23 +size 3055 diff --git a/validation_results.json b/validation_results.json new file mode 100644 index 0000000..0a84c45 --- /dev/null +++ b/validation_results.json @@ -0,0 +1,9 @@ +{ + "epoch": 60.0, + "eval_loss": 0.4502493739128113, + "eval_runtime": 418.5824, + "eval_samples": 20246, + "eval_samples_per_second": 48.368, + "eval_steps_per_second": 3.024, + "eval_wer": 0.23182961772311134 +} \ No newline at end of file diff --git a/vocab.json b/vocab.json new file mode 100644 index 0000000..eb991cf --- /dev/null +++ b/vocab.json @@ -0,0 +1 @@ +{"א": 1, "ב": 2, "ג": 3, "ד": 4, "ה": 5, "ו": 6, "ז": 7, "ח": 8, "ט": 9, "י": 10, "ך": 11, "כ": 12, "ל": 13, "ם": 14, "מ": 15, "ן": 16, "נ": 17, "ס": 18, "ע": 19, "ף": 20, "פ": 21, "ץ": 22, "צ": 23, "ק": 24, "ר": 25, "ש": 26, "ת": 27, "|": 0, "[UNK]": 28, "[PAD]": 29} \ No newline at end of file