Initialize project; model provided by the ModelHub XC community

Model: kotoba-tech/kotoba-whisper-bilingual-v1.0
Source: Original Platform
ModelHub XC
2026-05-15 08:05:30 +08:00
commit cb7d038c06
23 changed files with 234567 additions and 0 deletions

35
.gitattributes vendored Normal file

@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

208
README.md Normal file

@@ -0,0 +1,208 @@
---
license: apache-2.0
datasets:
- japanese-asr/en_asr.mls
- japanese-asr/ja_asr.reazon_speech_all
language:
- en
- ja
pipeline_tag: automatic-speech-recognition
library_name: transformers
tags:
- audio
- automatic-speech-recognition
- hf-asr-leaderboard
---
# Kotoba-Whisper-Bilingual (v1.0)
[**faster-whisper weight**](https://huggingface.co/kotoba-tech/kotoba-whisper-bilingual-v1.0-faster), [**whisper.cpp weight**](https://huggingface.co/kotoba-tech/kotoba-whisper-bilingual-v1.0-ggml)
_Kotoba-Whisper-Bilingual_ is a collection of distilled [Whisper](https://arxiv.org/abs/2212.04356) models trained for
- **Japanese ASR**
- **English ASR**
- **Speech-to-text translation (Japanese -> English)**
- **Speech-to-text translation (English -> Japanese)**
The models were developed through a collaboration between
[Asahi Ushio](https://asahiushio.com) and [Kotoba Technologies](https://twitter.com/kotoba_tech).
Following the original work of distil-whisper ([Robust Knowledge Distillation via Large-Scale Pseudo Labelling](https://arxiv.org/abs/2311.00430)),
we employ OpenAI's [Whisper large-v3](https://huggingface.co/openai/whisper-large-v3) as the teacher model for Japanese and English ASR, and we translate the
transcriptions into English and Japanese with an external LLM to obtain the training data for speech-to-text translation.
We employ [ReazonSpeech](https://huggingface.co/datasets/japanese-asr/ja_asr.reazon_speech_all) for Japanese ASR and Japanese speech to English text translation,
and [Multilingual LibriSpeech](https://huggingface.co/datasets/japanese-asr/en_asr.mls) for English ASR and English speech to Japanese text translation.
Kotoba-whisper-bilingual's loss objective consists of cross-entropy on both the ASR and translation tasks, while the KL divergence loss is applied only to the ASR task.
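
To make the objective concrete, here is a minimal NumPy sketch of a loss with that shape: cross-entropy on every example, plus a KL term that is masked out for translation examples. The function names, the `is_asr` flag, the loss weights and the absence of padding handling are all assumptions made for illustration; the actual training code may differ.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, labels, is_asr,
                      ce_weight=1.0, kl_weight=1.0):
    """Cross-entropy for all examples + KL(teacher || student) for ASR examples only.

    student_logits, teacher_logits: float arrays of shape (batch, seq_len, vocab)
    labels: int array of shape (batch, seq_len) with reference token ids
    is_asr: bool array of shape (batch,), True for transcription examples
    ce_weight, kl_weight: hypothetical weights, not taken from the source
    """
    s_logp = log_softmax(student_logits)
    t_logp = log_softmax(teacher_logits)

    # Cross-entropy: negative log-probability the student gives the reference token.
    ce = -np.take_along_axis(s_logp, labels[..., None], axis=-1).squeeze(-1)
    ce_loss = ce.mean()

    # Per-token KL divergence, averaged over the tokens of ASR examples only.
    kl = (np.exp(t_logp) * (t_logp - s_logp)).sum(axis=-1)
    mask = is_asr[:, None].astype(kl.dtype)
    n_asr_tokens = max(mask.sum() * kl.shape[1], 1.0)
    kl_loss = (kl * mask).sum() / n_asr_tokens

    return ce_weight * ce_loss + kl_weight * kl_loss
```

With `is_asr` all `False` the result reduces to plain cross-entropy, which is the behaviour described for the translation examples.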
The student model consists of the full encoder of the teacher large-v3 model and a two-layer decoder initialized from the first and last decoder layers of the large-v3 model.
As kotoba-whisper uses the same architecture as [distil-whisper/distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3),
it inherits the benefit of the improved latency compared to [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3)
(**6.3x faster than large-v3**, see the table below taken from [distil-whisper/distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3)).
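
The layer selection described above can be sketched in a few lines. This mirrors the `np.linspace` mapping used in the `create_student_model.py` script included in this repository; the helper name `layer_mapping` is hypothetical.

```python
import numpy as np


def layer_mapping(teacher_layers, student_layers):
    """Return {teacher_layer_index: student_layer_index} for evenly spaced layers."""
    mapping = np.linspace(0, teacher_layers - 1, student_layers, dtype=int)
    mapping[-1] = teacher_layers - 1  # always keep the teacher's last layer
    return {int(teacher): student for student, teacher in enumerate(mapping)}


# Whisper large-v3 has 32 decoder layers; the student keeps two of them.
print(layer_mapping(32, 2))  # {0: 0, 31: 1}
```

For the encoder, where the student keeps all 32 layers, the same function yields the identity mapping.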
## Evaluation
We compare our kotoba-whisper-bilingual with OpenAI whisper models, kotoba-whisper models, and cascaded models for translation.
**Note that kotoba-whisper-bilingual is the only model here that can do both Japanese and English ASR and speech-to-text translation between Japanese and English**, as
OpenAI whisper is not trained for English-to-Japanese speech-to-text translation, and the other models are task-specific (e.g. kotoba-whisper does Japanese ASR only and
distil-whisper does English ASR only).
### Speech2Text Translation (Japanese->English): WER (smaller is better)
| model | [CoVoST2 (Ja->En)](https://huggingface.co/datasets/japanese-asr/ja2en.s2t_translation)| [Fleurs (Ja->En)](https://huggingface.co/datasets/japanese-asr/ja2en.s2t_translation) |
|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------:|
| [**kotoba-tech/kotoba-whisper-bilingual-v1.0**](https://huggingface.co/kotoba-tech/kotoba-whisper-bilingual-v1.0) | 73.9 | 98.7 |
| [japanese-asr/ja-cascaded-s2t-translation](https://huggingface.co/japanese-asr/ja-cascaded-s2t-translation) ([facebook/nllb-200-3.3B](https://huggingface.co/facebook/nllb-200-3.3B)) | 64.3 | 67.1 |
| [japanese-asr/ja-cascaded-s2t-translation](https://huggingface.co/japanese-asr/ja-cascaded-s2t-translation) ([facebook/nllb-200-1.3B](https://huggingface.co/facebook/nllb-200-1.3B)) | 65.4 | 68.9 |
| [japanese-asr/ja-cascaded-s2t-translation](https://huggingface.co/japanese-asr/ja-cascaded-s2t-translation) ([facebook/nllb-200-distilled-1.3B](https://huggingface.co/facebook/nllb-200-distilled-1.3B)) | 65.6 | 67.4 |
| [japanese-asr/ja-cascaded-s2t-translation](https://huggingface.co/japanese-asr/ja-cascaded-s2t-translation) ([facebook/nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M)) | 68.2 | 72.2 |
| [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | 71 | 86.1 |
| [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) | 66.4 | 78.8 |
| [openai/whisper-large](https://huggingface.co/openai/whisper-large) | 66.5 | 86.1 |
| [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) | 70.3 | 97.2 |
| [openai/whisper-small](https://huggingface.co/openai/whisper-small) | 97.3 | 132.2 |
| [openai/whisper-base](https://huggingface.co/openai/whisper-base) | 186.2 | 349.6 |
| [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny) | 377.2 | 474 |
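
Several scores in these tables exceed 100. That is possible because WER and CER divide the edit distance by the reference length, so a hypothesis with many inserted tokens can contain more errors than the reference has tokens. The sketch below illustrates the metric only; the evaluation scripts behind the tables are not shown here and may normalise text differently, and the function names are hypothetical.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(reference, hypothesis, unit="word"):
    """WER (unit="word") or CER (unit="char") as a percentage of the reference length."""
    if unit == "word":
        ref, hyp = reference.split(), hypothesis.split()
    else:
        ref, hyp = list(reference.replace(" ", "")), list(hypothesis.replace(" ", ""))
    return 100.0 * edit_distance(ref, hyp) / len(ref)


# A repetitive hypothesis: 4 inserted words against a 2-word reference.
print(error_rate("thank you", "thank you thank you thank you"))  # 200.0
# Two substituted characters out of five.
print(error_rate("こんにちは", "こんばんは", unit="char"))  # 40.0
```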
### Speech2Text Translation (English->Japanese): CER (smaller is better)
| model | [CoVoST2 (En->Ja)](https://huggingface.co/datasets/japanese-asr/en2ja.s2t_translation)| [Fleurs (En->JA)](https://huggingface.co/datasets/japanese-asr/en2ja.s2t_translation) |
|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------:|
| [**kotoba-tech/kotoba-whisper-bilingual-v1.0**](https://huggingface.co/kotoba-tech/kotoba-whisper-bilingual-v1.0) | 69.1 | 74.4 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-3.3B](https://huggingface.co/facebook/nllb-200-3.3B)) | 62.4 | 63.5 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-1.3B](https://huggingface.co/facebook/nllb-200-1.3B)) | 64.4 | 67.2 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-distilled-1.3B](https://huggingface.co/facebook/nllb-200-distilled-1.3B)) | 62.4 | 62.9 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M)) | 63.4 | 66.2 |
| [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | 178.9 | 209.5 |
| [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) | 179.6 | 201.8 |
| [openai/whisper-large](https://huggingface.co/openai/whisper-large) | 178.7 | 201.8 |
| [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) | 178.7 | 202 |
| [openai/whisper-small](https://huggingface.co/openai/whisper-small) | 178.9 | 206.8 |
| [openai/whisper-base](https://huggingface.co/openai/whisper-base) | 179.5 | 214.2 |
| [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny) | 185.2 | 200.5 |
### ASR (Japanese): CER (smaller is better)
| model | [CommonVoice 8 (Japanese test set)](https://huggingface.co/datasets/japanese-asr/ja_asr.common_voice_8_0) | [JSUT Basic 5000](https://huggingface.co/datasets/japanese-asr/ja_asr.jsut_basic5000) | [ReazonSpeech (held out test set)](https://huggingface.co/datasets/japanese-asr/ja_asr.reazonspeech_test) |
|:--------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------:|----------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------:|
| [**kotoba-tech/kotoba-whisper-bilingual-v1.0**](https://huggingface.co/kotoba-tech/kotoba-whisper-bilingual-v1.0) | 9.8 | 9.3 | 16.8 |
| [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0) | 9.2 | 8.4 | 11.6 |
| [kotoba-tech/kotoba-whisper-v1.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.0) | 9.4 | 8.5 | 12.2 |
| [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | 8.5 | 7.1 | 14.9 |
| [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) | 9.7 | 8.2 | 28.1 |
| [openai/whisper-large](https://huggingface.co/openai/whisper-large) | 10 | 8.9 | 34.1 |
| [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) | 11.5 | 10 | 33.2 |
| [openai/whisper-small](https://huggingface.co/openai/whisper-small) | 15.1 | 14.2 | 41.5 |
| [openai/whisper-base](https://huggingface.co/openai/whisper-base) | 28.6 | 24.9 | 70.4 |
| [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny) | 53.7 | 36.5 | 137.9 |
| [reazon-research/reazonspeech-nemo-v2](https://huggingface.co/reazon-research/reazonspeech-nemo-v2) | 9.1 | 7.4 | 11.2 |
### ASR (English): WER (smaller is better)
| model | [ESB](https://huggingface.co/datasets/japanese-asr/en_asr.esb_eval) (ami) | [ESB](https://huggingface.co/datasets/japanese-asr/en_asr.esb_eval) (earnings22) | [ESB](https://huggingface.co/datasets/japanese-asr/en_asr.esb_eval) (librispeech) | [ESB](https://huggingface.co/datasets/japanese-asr/en_asr.esb_eval) (tedlium) | [ESB](https://huggingface.co/datasets/japanese-asr/en_asr.esb_eval) (voxpopuli) |
|:----------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------:|-----------------------------------------------------------------------------------:|------------------------------------------------------------------------------------:|--------------------------------------------------------------------------------:|----------------------------------------------------------------------------------:|
| [**kotoba-tech/kotoba-whisper-bilingual-v1.0**](https://huggingface.co/kotoba-tech/kotoba-whisper-bilingual-v1.0) | 16.7 | 15.3 | 2.4 | 4.1 | 8.3 |
| [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | 17.9 | 14.9 | 2.1 | 3.8 | 12.7 |
| [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) | 18.9 | 16.7 | 2.3 | 4.9 | 7.7 |
| [openai/whisper-large](https://huggingface.co/openai/whisper-large) | 18.8 | 14.9 | 2.6 | 4.2 | 7.7 |
| [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) | 18.3 | 14.9 | 2.5 | 4.3 | 7.9 |
| [openai/whisper-small](https://huggingface.co/openai/whisper-small) | 23.1 | 17.2 | 3.5 | 5.3 | 10.8 |
| [openai/whisper-base](https://huggingface.co/openai/whisper-base) | 26.6 | 21 | 6 | 6.1 | 11.3 |
| [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny) | 31.9 | 30.5 | 8.2 | 11.7 | 15.1 |
| [japanese-asr/distil-whisper-bilingual-v1.0](https://huggingface.co/japanese-asr/distil-whisper-bilingual-v1.0) | 20.7 | 18.6 | 2.4 | 6.4 | 10 |
### Inference Speed
Although the cascaded approach is more accurate on the translation tasks, it trades that accuracy for additional pipeline complexity and
memory consumption compared to a single end-to-end model.
The following table shows the mean inference time in seconds on a single RTX 4090 (24 GB VRAM), averaged over 10 trials on audio samples of different durations, along with the parameter count.
| model | Param. (M) | 10 s audio | 30 s audio | 60 s audio | 300 s audio |
|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------:|------:|------:|------:|------:|
| [**kotoba-tech/kotoba-whisper-bilingual-v1.0**](https://huggingface.co/kotoba-tech/kotoba-whisper-bilingual-v1.0) | 756 | 0.041 | 0.111 | 0.214 | 1.077 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-3.3B](https://huggingface.co/facebook/nllb-200-3.3B)) | 4056 | 0.173 | 0.247 | 0.352 | 1.772 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-1.3B](https://huggingface.co/facebook/nllb-200-1.3B)) | 2056 | 0.173 | 0.24 | 0.348 | 1.515 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-distilled-1.3B](https://huggingface.co/facebook/nllb-200-distilled-1.3B)) | 2056 | 0.17 | 0.245 | 0.348 | 1.882 |
| [japanese-asr/en-cascaded-s2t-translation](https://huggingface.co/japanese-asr/en-cascaded-s2t-translation) ([facebook/nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M)) | 1256 | 0.108 | 0.179 | 0.283 | 1.33 |
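
The 756M figure in the `Param. (M)` column can be cross-checked against the values in this repository's `config.json`. The sketch below counts parameters by hand under stated assumptions about the layer shapes (Hugging Face Whisper layout: no bias on the key projection, two 1-D convolutions with kernel size 3 in the encoder, output projection tied to the token embedding). It is a back-of-the-envelope estimate, not the official way to count parameters.

```python
# Values copied from config.json in this repository.
CONFIG = {
    "d_model": 1280,
    "encoder_layers": 32,
    "decoder_layers": 2,
    "encoder_ffn_dim": 5120,
    "decoder_ffn_dim": 5120,
    "num_mel_bins": 128,
    "max_source_positions": 1500,
    "max_target_positions": 448,
    "vocab_size": 51866,
}


def attention_params(d):
    # q, v and output projections have a bias; the key projection is assumed not to.
    return 4 * d * d + 3 * d


def ffn_params(d, ffn_dim):
    return (d * ffn_dim + ffn_dim) + (ffn_dim * d + d)


def layer_norm_params(d):
    return 2 * d  # weight and bias


def whisper_param_count(cfg):
    d = cfg["d_model"]
    enc_layer = attention_params(d) + ffn_params(d, cfg["encoder_ffn_dim"]) + 2 * layer_norm_params(d)
    # Decoder layers have self-attention and cross-attention.
    dec_layer = 2 * attention_params(d) + ffn_params(d, cfg["decoder_ffn_dim"]) + 3 * layer_norm_params(d)
    convs = (cfg["num_mel_bins"] * d * 3 + d) + (d * d * 3 + d)
    encoder = (convs + cfg["max_source_positions"] * d
               + cfg["encoder_layers"] * enc_layer + layer_norm_params(d))
    decoder = (cfg["vocab_size"] * d + cfg["max_target_positions"] * d
               + cfg["decoder_layers"] * dec_layer + layer_norm_params(d))
    return encoder + decoder  # output projection assumed tied to the token embedding


print(f"{whisper_param_count(CONFIG) / 1e6:.1f} M parameters")
```

Under these assumptions the estimate lands at roughly 756M, with the 32-layer encoder accounting for most of it.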
## Transformers Usage
Kotoba-Whisper is supported in the Hugging Face 🤗 Transformers library from version 4.39 onwards. To run the model, first
install the latest version of Transformers.
```bash
pip install --upgrade pip
pip install --upgrade transformers accelerate
```
The model can be used with the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline)
class to transcribe short-form audio files (< 30 seconds). First, download the sample audio:
```shell
wget https://huggingface.co/datasets/japanese-asr/en_asr.esb_eval/resolve/main/sample.wav -O sample_en.wav
wget https://huggingface.co/datasets/japanese-asr/ja_asr.jsut_basic5000/resolve/main/sample.flac -O sample_ja.flac
```
```python
import torch
from transformers import pipeline

# config: use bfloat16 and SDPA attention on GPU, float32 on CPU
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
pipe = pipeline(
    "automatic-speech-recognition",
    model="kotoba-tech/kotoba-whisper-bilingual-v1.0",
    torch_dtype=torch_dtype,
    device=device,
    model_kwargs=model_kwargs,
    chunk_length_s=15,
    batch_size=16,
)
# Japanese ASR
generate_kwargs = {"language": "ja", "task": "transcribe"}
result = pipe("sample_ja.flac", generate_kwargs=generate_kwargs)
print(result["text"])
# English ASR
generate_kwargs = {"language": "en", "task": "transcribe"}
result = pipe("sample_en.wav", generate_kwargs=generate_kwargs)
print(result["text"])
# Translate Japanese speech to English text
generate_kwargs = {"language": "en", "task": "translate"}
result = pipe("sample_ja.flac", generate_kwargs=generate_kwargs)
print(result["text"])
# Translate English speech to Japanese text
generate_kwargs = {"language": "ja", "task": "translate"}
result = pipe("sample_en.wav", generate_kwargs=generate_kwargs)
print(result["text"])
```
- For segment-level timestamps, pass the argument `return_timestamps=True` and read the `"chunks"` field of the output:
```python
result = pipe("sample_ja.flac", return_timestamps=True, generate_kwargs=generate_kwargs)
print(result["chunks"])
```
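
As a small illustration of working with that output, the helper below renders the chunks as readable lines. It assumes each chunk is a dict with a `"timestamp"` pair of `(start, end)` seconds and a `"text"` string, which is how the Transformers ASR pipeline returns them; the end time of the final chunk can be `None`. The function name and the sample data are hypothetical.

```python
def format_chunks(chunks):
    """Render pipeline chunks as '[start -> end] text' lines."""
    lines = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        end_text = f"{end:.2f}" if end is not None else "end"
        lines.append(f"[{start:.2f} -> {end_text}] {chunk['text'].strip()}")
    return "\n".join(lines)


# Hypothetical chunks, shaped like result["chunks"].
example_chunks = [
    {"timestamp": (0.0, 2.5), "text": " Hello there."},
    {"timestamp": (2.5, None), "text": " How are you?"},
]
print(format_chunks(example_chunks))
```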
## Training
Please refer to [https://github.com/kotoba-tech/kotoba-whisper](https://github.com/kotoba-tech/kotoba-whisper) for details of the model training.
The datasets used in distillation and all model variants can be found at [https://huggingface.co/japanese-asr](https://huggingface.co/japanese-asr).
## Acknowledgements
* [OpenAI](https://openai.com/) for the Whisper [model](https://huggingface.co/openai/whisper-large-v3).
* Hugging Face 🤗 [Transformers](https://github.com/huggingface/transformers) for the model integration.
* Hugging Face 🤗 for the [Distil-Whisper codebase](https://github.com/huggingface/distil-whisper).
* [Reazon Human Interaction Lab](https://research.reazon.jp/) for the [ReazonSpeech dataset](https://huggingface.co/datasets/reazon-research/reazonspeech).

1611
added_tokens.json Normal file

File diff suppressed because it is too large

50
config.json Normal file

@@ -0,0 +1,50 @@
{
"_name_or_path": "distil-whisper-bilingual",
"activation_dropout": 0.0,
"activation_function": "gelu",
"apply_spec_augment": false,
"architectures": [
"WhisperForConditionalGeneration"
],
"attention_dropout": 0.0,
"begin_suppress_tokens": [
220,
50257
],
"bos_token_id": 50257,
"classifier_proj_size": 256,
"d_model": 1280,
"decoder_attention_heads": 20,
"decoder_ffn_dim": 5120,
"decoder_layerdrop": 0.0,
"decoder_layers": 2,
"decoder_start_token_id": 50258,
"dropout": 0.0,
"encoder_attention_heads": 20,
"encoder_ffn_dim": 5120,
"encoder_layerdrop": 0.0,
"encoder_layers": 32,
"eos_token_id": 50257,
"init_std": 0.02,
"is_encoder_decoder": true,
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_prob": 0.05,
"max_length": 448,
"max_source_positions": 1500,
"max_target_positions": 448,
"median_filter_width": 7,
"model_type": "whisper",
"num_hidden_layers": 32,
"num_mel_bins": 128,
"pad_token_id": 50256,
"scale_embedding": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.44.2",
"use_cache": true,
"use_weighted_layer_sum": false,
"vocab_size": 51866
}

164
create_student_model.py Normal file

@@ -0,0 +1,164 @@
"""Initialize a student Whisper model from a pre-trained teacher model for teacher-student distillation."""
import argparse
import copy
import logging
import os

import numpy as np
import torch
from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor

# https://stackoverflow.com/questions/71692354/facing-ssl-error-with-huggingface-pretrained-models
os.environ['CURL_CA_BUNDLE'] = ''
# disable warning message
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
    )
    parser.add_argument(
        "--teacher_checkpoint",
        type=str,
        required=True,
        help="The HF Hub ID of the teacher checkpoint.",
    )
    parser.add_argument(
        "--encoder_layers",
        type=int,
        default=None,
        help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
    )
    parser.add_argument(
        "--decoder_layers",
        type=int,
        default=2,
        help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Where to save the student weights and processor.",
    )
    args = parser.parse_args()
    return args


def init_student_model_from_teacher(
    teacher_checkpoint,
    save_dir,
    encoder_layers=None,
    decoder_layers=2,
):
    teacher_model = WhisperForConditionalGeneration.from_pretrained(
        teacher_checkpoint,
        low_cpu_mem_usage=True,
    )
    processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
    generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)

    teacher_config = teacher_model.config
    teacher_encoder_layers = teacher_config.encoder_layers
    teacher_decoder_layers = teacher_config.decoder_layers

    student_config = copy.deepcopy(teacher_config)
    student_config.update(
        {
            "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
            "decoder_layers": decoder_layers,
        }
    )

    # map evenly spaced teacher layers onto the student layers, always keeping the last teacher layer
    encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
    encoder_mapping[-1] = teacher_encoder_layers - 1
    encoder_map = {}
    for student_layer, teacher_layer in enumerate(encoder_mapping):
        encoder_map[teacher_layer] = student_layer

    decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
    decoder_mapping[-1] = teacher_decoder_layers - 1
    decoder_map = {}
    for student_layer, teacher_layer in enumerate(decoder_mapping):
        decoder_map[teacher_layer] = student_layer

    # init the student params from the teacher model
    student_model = WhisperForConditionalGeneration(student_config)
    missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
    if len(missing_keys) > 0:
        raise RuntimeError(
            "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
            f"Missing key(s) in state_dict: {missing_keys}"
        )
    if decoder_layers == teacher_decoder_layers:
        decoder_keys = [key for key in unexpected_keys if "model.decoder.layers" in key]
        if len(decoder_keys) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
                f"Unexpected key(s) in state_dict: {decoder_keys}"
            )
    if encoder_layers == teacher_encoder_layers:
        encoder_keys = [key for key in unexpected_keys if "model.encoder.layers" in key]
        if len(encoder_keys) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
                f"Unexpected key(s) in state_dict: {encoder_keys}"
            )

    for layer in range(teacher_decoder_layers):
        if layer in decoder_map:
            # re-introduce pre-defined layers from the teacher
            student_model.model.decoder.layers[decoder_map[layer]].load_state_dict(
                teacher_model.model.decoder.layers[layer].state_dict()
            )

    if encoder_layers is not None:
        for layer in range(teacher_encoder_layers):
            if layer in encoder_map:
                # re-introduce pre-defined layers from the teacher
                student_model.model.encoder.layers[encoder_map[layer]].load_state_dict(
                    teacher_model.model.encoder.layers[layer].state_dict()
                )

    # remove the teacher params and model
    del teacher_model

    # save the converted weights and model
    student_model.save_pretrained(save_dir)
    # we also need to correctly save the processor and generation config
    processor.save_pretrained(save_dir)
    generation_config.save_pretrained(save_dir)

    # check we can do a forward pass with the saved model - first load the weights and processor
    logger.info("Checking we can load the saved model...")
    student_model = WhisperForConditionalGeneration.from_pretrained(save_dir, low_cpu_mem_usage=True)
    processor = WhisperProcessor.from_pretrained(save_dir)

    # define some dummy inputs: one second of constant audio at 16 kHz
    input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="pt").input_features
    decoder_start_token_id = student_model.config.decoder_start_token_id
    decoder_input_ids = torch.ones((input_features.shape[0], 1), dtype=torch.long) * decoder_start_token_id

    # do a forward pass - outputs will be gibberish for the initialised model so we can't check them
    # but we can make sure the model runs as expected
    logger.info("Checking we can run the converted model forward...")
    _ = student_model(input_features, decoder_input_ids=decoder_input_ids).logits
    logger.info("Conversion successful!")


if __name__ == "__main__":
    args = parse_args()
    init_student_model_from_teacher(
        teacher_checkpoint=args.teacher_checkpoint,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        save_dir=args.save_dir,
    )

File diff suppressed because it is too large


@@ -0,0 +1,50 @@
{
"_name_or_path": "openai/whisper-large-v3",
"activation_dropout": 0.0,
"activation_function": "gelu",
"apply_spec_augment": false,
"architectures": [
"WhisperForConditionalGeneration"
],
"attention_dropout": 0.0,
"begin_suppress_tokens": [
220,
50257
],
"bos_token_id": 50257,
"classifier_proj_size": 256,
"d_model": 1280,
"decoder_attention_heads": 20,
"decoder_ffn_dim": 5120,
"decoder_layerdrop": 0.0,
"decoder_layers": 2,
"decoder_start_token_id": 50258,
"dropout": 0.0,
"encoder_attention_heads": 20,
"encoder_ffn_dim": 5120,
"encoder_layerdrop": 0.0,
"encoder_layers": 32,
"eos_token_id": 50257,
"init_std": 0.02,
"is_encoder_decoder": true,
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_prob": 0.05,
"max_length": 448,
"max_source_positions": 1500,
"max_target_positions": 448,
"median_filter_width": 7,
"model_type": "whisper",
"num_hidden_layers": 32,
"num_mel_bins": 128,
"pad_token_id": 50256,
"scale_embedding": false,
"torch_dtype": "float32",
"transformers_version": "4.44.2",
"use_cache": true,
"use_weighted_layer_sum": false,
"vocab_size": 51866
}


@@ -0,0 +1,265 @@
{
"alignment_heads": [
[
7,
0
],
[
10,
17
],
[
12,
18
],
[
13,
12
],
[
16,
1
],
[
17,
14
],
[
19,
11
],
[
21,
4
],
[
24,
1
],
[
25,
6
]
],
"begin_suppress_tokens": [
220,
50257
],
"bos_token_id": 50257,
"decoder_start_token_id": 50258,
"eos_token_id": 50257,
"forced_decoder_ids": [
[
1,
null
],
[
2,
50360
]
],
"is_multilingual": true,
"lang_to_id": {
"<|af|>": 50327,
"<|am|>": 50334,
"<|ar|>": 50272,
"<|as|>": 50350,
"<|az|>": 50304,
"<|ba|>": 50355,
"<|be|>": 50330,
"<|bg|>": 50292,
"<|bn|>": 50302,
"<|bo|>": 50347,
"<|br|>": 50309,
"<|bs|>": 50315,
"<|ca|>": 50270,
"<|cs|>": 50283,
"<|cy|>": 50297,
"<|da|>": 50285,
"<|de|>": 50261,
"<|el|>": 50281,
"<|en|>": 50259,
"<|es|>": 50262,
"<|et|>": 50307,
"<|eu|>": 50310,
"<|fa|>": 50300,
"<|fi|>": 50277,
"<|fo|>": 50338,
"<|fr|>": 50265,
"<|gl|>": 50319,
"<|gu|>": 50333,
"<|haw|>": 50352,
"<|ha|>": 50354,
"<|he|>": 50279,
"<|hi|>": 50276,
"<|hr|>": 50291,
"<|ht|>": 50339,
"<|hu|>": 50286,
"<|hy|>": 50312,
"<|id|>": 50275,
"<|is|>": 50311,
"<|it|>": 50274,
"<|ja|>": 50266,
"<|jw|>": 50356,
"<|ka|>": 50329,
"<|kk|>": 50316,
"<|km|>": 50323,
"<|kn|>": 50306,
"<|ko|>": 50264,
"<|la|>": 50294,
"<|lb|>": 50345,
"<|ln|>": 50353,
"<|lo|>": 50336,
"<|lt|>": 50293,
"<|lv|>": 50301,
"<|mg|>": 50349,
"<|mi|>": 50295,
"<|mk|>": 50308,
"<|ml|>": 50296,
"<|mn|>": 50314,
"<|mr|>": 50320,
"<|ms|>": 50282,
"<|mt|>": 50343,
"<|my|>": 50346,
"<|ne|>": 50313,
"<|nl|>": 50271,
"<|nn|>": 50342,
"<|no|>": 50288,
"<|oc|>": 50328,
"<|pa|>": 50321,
"<|pl|>": 50269,
"<|ps|>": 50340,
"<|pt|>": 50267,
"<|ro|>": 50284,
"<|ru|>": 50263,
"<|sa|>": 50344,
"<|sd|>": 50332,
"<|si|>": 50322,
"<|sk|>": 50298,
"<|sl|>": 50305,
"<|sn|>": 50324,
"<|so|>": 50326,
"<|sq|>": 50317,
"<|sr|>": 50303,
"<|su|>": 50357,
"<|sv|>": 50273,
"<|sw|>": 50318,
"<|ta|>": 50287,
"<|te|>": 50299,
"<|tg|>": 50331,
"<|th|>": 50289,
"<|tk|>": 50341,
"<|tl|>": 50348,
"<|tr|>": 50268,
"<|tt|>": 50351,
"<|uk|>": 50280,
"<|ur|>": 50290,
"<|uz|>": 50337,
"<|vi|>": 50278,
"<|yi|>": 50335,
"<|yo|>": 50325,
"<|yue|>": 50358,
"<|zh|>": 50260
},
"max_initial_timestamp_index": 50,
"max_length": 448,
"no_timestamps_token_id": 50364,
"pad_token_id": 50257,
"prev_sot_token_id": 50362,
"return_timestamps": false,
"suppress_tokens": [
1,
2,
7,
8,
9,
10,
14,
25,
26,
27,
28,
29,
31,
58,
59,
60,
61,
62,
63,
90,
91,
92,
93,
359,
503,
522,
542,
873,
893,
902,
918,
922,
931,
1350,
1853,
1982,
2460,
2627,
3246,
3253,
3268,
3536,
3846,
3961,
4183,
4667,
6585,
6647,
7273,
9061,
9383,
10428,
10929,
11938,
12033,
12331,
12562,
13793,
14157,
14635,
15265,
15618,
16553,
16604,
18362,
18956,
20075,
21675,
22520,
26130,
26161,
26435,
28279,
29464,
31650,
32302,
32470,
36865,
42863,
47425,
49870,
50254,
50258,
50359,
50360,
50361,
50362,
50363
],
"task_to_id": {
"transcribe": 50360,
"translate": 50359
},
"transformers_version": "4.44.2"
}
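
The `lang_to_id` and `task_to_id` tables above are what turn `language` and `task` arguments into the special tokens at the start of the decoder sequence. The sketch below builds that prefix by hand, assuming the usual Whisper prompt order (start-of-transcript, language, task, optional no-timestamps token). The ids are copied from this generation config; the function name is hypothetical, and in practice the Transformers `generate` method does this for you.

```python
# Ids copied from the generation config above.
DECODER_START_TOKEN_ID = 50258  # <|startoftranscript|>
NO_TIMESTAMPS_TOKEN_ID = 50364  # <|notimestamps|>
LANG_TO_ID = {"<|en|>": 50259, "<|ja|>": 50266}
TASK_TO_ID = {"transcribe": 50360, "translate": 50359}


def decoder_prompt(language, task, timestamps=False):
    """Build the special-token prefix for a given language code and task."""
    ids = [DECODER_START_TOKEN_ID, LANG_TO_ID[f"<|{language}|>"], TASK_TO_ID[task]]
    if not timestamps:
        ids.append(NO_TIMESTAMPS_TOKEN_ID)
    return ids


print(decoder_prompt("ja", "transcribe"))  # [50258, 50266, 50360, 50364]
print(decoder_prompt("en", "translate"))   # [50258, 50259, 50359, 50364]
```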

File diff suppressed because it is too large


@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c5ef44f7f59126b7b66937cc81d3194eb310f9af8b08512bbd6bd55fb0cda9f
size 3025686376
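
The three lines above are a Git LFS pointer: the repository stores this small text stub while the weight file itself lives in LFS storage. A minimal sketch of reading such a pointer is shown below; the function name is hypothetical, and it assumes the simple `key value` line format seen here.

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into its version, hash and size."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algorithm, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algorithm": algorithm,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


POINTER = """\
version https://git-lfs.github.com/spec/v1
oid sha256:2c5ef44f7f59126b7b66937cc81d3194eb310f9af8b08512bbd6bd55fb0cda9f
size 3025686376
"""

info = parse_lfs_pointer(POINTER)
print(info["hash_algorithm"], f"{info['size_bytes'] / 1e9:.2f} GB")
```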

File diff suppressed because it is too large


@@ -0,0 +1,14 @@
{
"chunk_length": 30,
"feature_extractor_type": "WhisperFeatureExtractor",
"feature_size": 128,
"hop_length": 160,
"n_fft": 400,
"n_samples": 480000,
"nb_max_frames": 3000,
"padding_side": "right",
"padding_value": 0.0,
"processor_class": "WhisperProcessor",
"return_attention_mask": false,
"sampling_rate": 16000
}
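
The numbers in this feature-extractor config are related by simple arithmetic, which the sketch below checks. The last step, halving the frame count to get the encoder sequence length, assumes the stride-2 convolution in the Whisper encoder and matches `max_source_positions` in `config.json`.

```python
# Values copied from the feature-extractor config above.
chunk_length = 30        # seconds of audio per window
sampling_rate = 16000    # samples per second
hop_length = 160         # samples between successive spectrogram frames

n_samples = chunk_length * sampling_rate  # samples per 30-second window
nb_max_frames = n_samples // hop_length   # mel frames per window
encoder_positions = nb_max_frames // 2    # assumed stride-2 downsampling in the encoder

print(n_samples, nb_max_frames, encoder_positions)  # 480000 3000 1500
```

Each mel frame therefore covers 10 ms of audio, and each encoder position 20 ms.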


@@ -0,0 +1,139 @@
{
"additional_special_tokens": [
"<|startoftranscript|>",
"<|en|>",
"<|zh|>",
"<|de|>",
"<|es|>",
"<|ru|>",
"<|ko|>",
"<|fr|>",
"<|ja|>",
"<|pt|>",
"<|tr|>",
"<|pl|>",
"<|ca|>",
"<|nl|>",
"<|ar|>",
"<|sv|>",
"<|it|>",
"<|id|>",
"<|hi|>",
"<|fi|>",
"<|vi|>",
"<|he|>",
"<|uk|>",
"<|el|>",
"<|ms|>",
"<|cs|>",
"<|ro|>",
"<|da|>",
"<|hu|>",
"<|ta|>",
"<|no|>",
"<|th|>",
"<|ur|>",
"<|hr|>",
"<|bg|>",
"<|lt|>",
"<|la|>",
"<|mi|>",
"<|ml|>",
"<|cy|>",
"<|sk|>",
"<|te|>",
"<|fa|>",
"<|lv|>",
"<|bn|>",
"<|sr|>",
"<|az|>",
"<|sl|>",
"<|kn|>",
"<|et|>",
"<|mk|>",
"<|br|>",
"<|eu|>",
"<|is|>",
"<|hy|>",
"<|ne|>",
"<|mn|>",
"<|bs|>",
"<|kk|>",
"<|sq|>",
"<|sw|>",
"<|gl|>",
"<|mr|>",
"<|pa|>",
"<|si|>",
"<|km|>",
"<|sn|>",
"<|yo|>",
"<|so|>",
"<|af|>",
"<|oc|>",
"<|ka|>",
"<|be|>",
"<|tg|>",
"<|sd|>",
"<|gu|>",
"<|am|>",
"<|yi|>",
"<|lo|>",
"<|uz|>",
"<|fo|>",
"<|ht|>",
"<|ps|>",
"<|tk|>",
"<|nn|>",
"<|mt|>",
"<|sa|>",
"<|lb|>",
"<|my|>",
"<|bo|>",
"<|tl|>",
"<|mg|>",
"<|as|>",
"<|tt|>",
"<|haw|>",
"<|ln|>",
"<|ha|>",
"<|ba|>",
"<|jw|>",
"<|su|>",
"<|yue|>",
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>"
],
"bos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large

265
generation_config.json Normal file

@@ -0,0 +1,265 @@
{
"alignment_heads": [
[
7,
0
],
[
10,
17
],
[
12,
18
],
[
13,
12
],
[
16,
1
],
[
17,
14
],
[
19,
11
],
[
21,
4
],
[
24,
1
],
[
25,
6
]
],
"begin_suppress_tokens": [
220,
50257
],
"bos_token_id": 50257,
"decoder_start_token_id": 50258,
"eos_token_id": 50257,
"forced_decoder_ids": [
[
1,
null
],
[
2,
50360
]
],
"is_multilingual": true,
"lang_to_id": {
"<|af|>": 50327,
"<|am|>": 50334,
"<|ar|>": 50272,
"<|as|>": 50350,
"<|az|>": 50304,
"<|ba|>": 50355,
"<|be|>": 50330,
"<|bg|>": 50292,
"<|bn|>": 50302,
"<|bo|>": 50347,
"<|br|>": 50309,
"<|bs|>": 50315,
"<|ca|>": 50270,
"<|cs|>": 50283,
"<|cy|>": 50297,
"<|da|>": 50285,
"<|de|>": 50261,
"<|el|>": 50281,
"<|en|>": 50259,
"<|es|>": 50262,
"<|et|>": 50307,
"<|eu|>": 50310,
"<|fa|>": 50300,
"<|fi|>": 50277,
"<|fo|>": 50338,
"<|fr|>": 50265,
"<|gl|>": 50319,
"<|gu|>": 50333,
"<|haw|>": 50352,
"<|ha|>": 50354,
"<|he|>": 50279,
"<|hi|>": 50276,
"<|hr|>": 50291,
"<|ht|>": 50339,
"<|hu|>": 50286,
"<|hy|>": 50312,
"<|id|>": 50275,
"<|is|>": 50311,
"<|it|>": 50274,
"<|ja|>": 50266,
"<|jw|>": 50356,
"<|ka|>": 50329,
"<|kk|>": 50316,
"<|km|>": 50323,
"<|kn|>": 50306,
"<|ko|>": 50264,
"<|la|>": 50294,
"<|lb|>": 50345,
"<|ln|>": 50353,
"<|lo|>": 50336,
"<|lt|>": 50293,
"<|lv|>": 50301,
"<|mg|>": 50349,
"<|mi|>": 50295,
"<|mk|>": 50308,
"<|ml|>": 50296,
"<|mn|>": 50314,
"<|mr|>": 50320,
"<|ms|>": 50282,
"<|mt|>": 50343,
"<|my|>": 50346,
"<|ne|>": 50313,
"<|nl|>": 50271,
"<|nn|>": 50342,
"<|no|>": 50288,
"<|oc|>": 50328,
"<|pa|>": 50321,
"<|pl|>": 50269,
"<|ps|>": 50340,
"<|pt|>": 50267,
"<|ro|>": 50284,
"<|ru|>": 50263,
"<|sa|>": 50344,
"<|sd|>": 50332,
"<|si|>": 50322,
"<|sk|>": 50298,
"<|sl|>": 50305,
"<|sn|>": 50324,
"<|so|>": 50326,
"<|sq|>": 50317,
"<|sr|>": 50303,
"<|su|>": 50357,
"<|sv|>": 50273,
"<|sw|>": 50318,
"<|ta|>": 50287,
"<|te|>": 50299,
"<|tg|>": 50331,
"<|th|>": 50289,
"<|tk|>": 50341,
"<|tl|>": 50348,
"<|tr|>": 50268,
"<|tt|>": 50351,
"<|uk|>": 50280,
"<|ur|>": 50290,
"<|uz|>": 50337,
"<|vi|>": 50278,
"<|yi|>": 50335,
"<|yo|>": 50325,
"<|yue|>": 50358,
"<|zh|>": 50260
},
"max_initial_timestamp_index": 50,
"max_length": 128,
"no_timestamps_token_id": 50364,
"pad_token_id": 50257,
"prev_sot_token_id": 50362,
"return_timestamps": false,
"suppress_tokens": [
1,
2,
7,
8,
9,
10,
14,
25,
26,
27,
28,
29,
31,
58,
59,
60,
61,
62,
63,
90,
91,
92,
93,
359,
503,
522,
542,
873,
893,
902,
918,
922,
931,
1350,
1853,
1982,
2460,
2627,
3246,
3253,
3268,
3536,
3846,
3961,
4183,
4667,
6585,
6647,
7273,
9061,
9383,
10428,
10929,
11938,
12033,
12331,
12562,
13793,
14157,
14635,
15265,
15618,
16553,
16604,
18362,
18956,
20075,
21675,
22520,
26130,
26161,
26435,
28279,
29464,
31650,
32302,
32470,
36865,
42863,
47425,
49870,
50254,
50258,
50359,
50360,
50361,
50362,
50363
],
"task_to_id": {
"transcribe": 50360,
"translate": 50359
},
"transformers_version": "4.44.2"
}
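A minimal sketch of how the IDs in `generation_config.json` relate to each other. It copies only a few keys from the file above. The helper names `describe_forced_ids` and `build_prompt_ids` are hypothetical, and the prompt ordering (start token, language, task, optional no-timestamps token) is the conventional Whisper layout, assumed here rather than stated by this commit.

```python
import json

# Excerpt of the generation_config.json values shown above; only the keys
# used in this sketch are copied, and only two of the language entries.
GENERATION_CONFIG = json.loads("""
{
  "decoder_start_token_id": 50258,
  "forced_decoder_ids": [[1, null], [2, 50360]],
  "lang_to_id": {"<|en|>": 50259, "<|ja|>": 50266},
  "task_to_id": {"transcribe": 50360, "translate": 50359},
  "no_timestamps_token_id": 50364
}
""")


def describe_forced_ids(config):
    """Label each forced decoder position (hypothetical helper, not a library call)."""
    names = {token_id: task for task, token_id in config["task_to_id"].items()}
    names.update({token_id: lang for lang, token_id in config["lang_to_id"].items()})
    return {
        position: "unset" if token_id is None else names.get(token_id, str(token_id))
        for position, token_id in config["forced_decoder_ids"]
    }


def build_prompt_ids(config, language, task, timestamps=False):
    """Assemble start, language, task and optional no-timestamps IDs.

    The ordering is the conventional Whisper prompt layout and is an
    assumption here, not something this commit states.
    """
    ids = [
        config["decoder_start_token_id"],
        config["lang_to_id"][language],
        config["task_to_id"][task],
    ]
    if not timestamps:
        ids.append(config["no_timestamps_token_id"])
    return ids


forced = describe_forced_ids(GENERATION_CONFIG)
prompt = build_prompt_ids(GENERATION_CONFIG, "<|ja|>", "transcribe")
print(forced)
print(prompt)
```

Read this way, the stored `forced_decoder_ids` leave position 1 (the language slot) unset and pin position 2 to the `transcribe` task ID, so the language would have to be supplied or detected at generation time.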

50001
merges.txt Normal file

File diff suppressed because it is too large.

3
model.safetensors Normal file

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3a7c08e4dfe000c70015f0d538cbcc3773785cc5b817e7b48fde4182206800c
size 1512875008
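The three lines above are a Git LFS pointer, not the weights themselves. The sketch below parses such a pointer into its fields; `parse_lfs_pointer` is a hypothetical helper written for illustration, and the pointer text is copied from this diff.

```python
# Pointer text copied verbatim from the model.safetensors diff above.
POINTER_TEXT = """\
version https://git-lfs.github.com/spec/v1
oid sha256:b3a7c08e4dfe000c70015f0d538cbcc3773785cc5b817e7b48fde4182206800c
size 1512875008
"""


def parse_lfs_pointer(text):
    """Split a Git LFS pointer into version, hash algorithm, digest and size."""
    # Each non-empty line is "<key> <value>", separated by a single space.
    fields = dict(line.split(" ", 1) for line in text.splitlines() if line.strip())
    algorithm, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algorithm": algorithm,
        "digest": digest,
        "size": int(fields["size"]),  # size of the real file in bytes
    }


pointer = parse_lfs_pointer(POINTER_TEXT)
size_gib = pointer["size"] / 2**30
print(pointer["algorithm"], pointer["digest"][:12], f"{size_gib:.2f} GiB")
```

The `size` field gives the byte count of the real file, which can be compared against a downloaded copy before hashing it.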

1742
normalizer.json Normal file

File diff suppressed because it is too large.

14
preprocessor_config.json Normal file

@@ -0,0 +1,14 @@
{
"chunk_length": 30,
"feature_extractor_type": "WhisperFeatureExtractor",
"feature_size": 128,
"hop_length": 160,
"n_fft": 400,
"n_samples": 480000,
"nb_max_frames": 3000,
"padding_side": "right",
"padding_value": 0.0,
"processor_class": "WhisperProcessor",
"return_attention_mask": false,
"sampling_rate": 16000
}
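Several values in `preprocessor_config.json` can be cross-checked against each other. The sketch below assumes `chunk_length` is in seconds and that `n_fft` is also the analysis window length in samples; both are assumptions, not facts stated by this commit. `derived_values` is a hypothetical helper.

```python
# Values copied from the preprocessor_config.json shown above.
PREPROCESSOR_CONFIG = {
    "chunk_length": 30,
    "sampling_rate": 16000,
    "n_samples": 480000,
    "hop_length": 160,
    "n_fft": 400,
    "nb_max_frames": 3000,
    "feature_size": 128,
}


def derived_values(cfg):
    """Recompute the dependent fields from the independent ones."""
    samples = cfg["chunk_length"] * cfg["sampling_rate"]  # samples per chunk
    frames = samples // cfg["hop_length"]                 # frames per chunk
    return {
        "n_samples": samples,
        "nb_max_frames": frames,
        "hop_ms": 1000 * cfg["hop_length"] / cfg["sampling_rate"],
        "window_ms": 1000 * cfg["n_fft"] / cfg["sampling_rate"],
        "feature_shape": (cfg["feature_size"], frames),
    }


derived = derived_values(PREPROCESSOR_CONFIG)
print(derived)
```

Under these assumptions the stored `n_samples` and `nb_max_frames` match the recomputed values, and each 30-second chunk becomes a 128 by 3000 feature matrix.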

139
special_tokens_map.json Normal file

@@ -0,0 +1,139 @@
{
"additional_special_tokens": [
"<|startoftranscript|>",
"<|en|>",
"<|zh|>",
"<|de|>",
"<|es|>",
"<|ru|>",
"<|ko|>",
"<|fr|>",
"<|ja|>",
"<|pt|>",
"<|tr|>",
"<|pl|>",
"<|ca|>",
"<|nl|>",
"<|ar|>",
"<|sv|>",
"<|it|>",
"<|id|>",
"<|hi|>",
"<|fi|>",
"<|vi|>",
"<|he|>",
"<|uk|>",
"<|el|>",
"<|ms|>",
"<|cs|>",
"<|ro|>",
"<|da|>",
"<|hu|>",
"<|ta|>",
"<|no|>",
"<|th|>",
"<|ur|>",
"<|hr|>",
"<|bg|>",
"<|lt|>",
"<|la|>",
"<|mi|>",
"<|ml|>",
"<|cy|>",
"<|sk|>",
"<|te|>",
"<|fa|>",
"<|lv|>",
"<|bn|>",
"<|sr|>",
"<|az|>",
"<|sl|>",
"<|kn|>",
"<|et|>",
"<|mk|>",
"<|br|>",
"<|eu|>",
"<|is|>",
"<|hy|>",
"<|ne|>",
"<|mn|>",
"<|bs|>",
"<|kk|>",
"<|sq|>",
"<|sw|>",
"<|gl|>",
"<|mr|>",
"<|pa|>",
"<|si|>",
"<|km|>",
"<|sn|>",
"<|yo|>",
"<|so|>",
"<|af|>",
"<|oc|>",
"<|ka|>",
"<|be|>",
"<|tg|>",
"<|sd|>",
"<|gu|>",
"<|am|>",
"<|yi|>",
"<|lo|>",
"<|uz|>",
"<|fo|>",
"<|ht|>",
"<|ps|>",
"<|tk|>",
"<|nn|>",
"<|mt|>",
"<|sa|>",
"<|lb|>",
"<|my|>",
"<|bo|>",
"<|tl|>",
"<|mg|>",
"<|as|>",
"<|tt|>",
"<|haw|>",
"<|ln|>",
"<|ha|>",
"<|ba|>",
"<|jw|>",
"<|su|>",
"<|yue|>",
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>"
],
"bos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

12996
tokenizer_config.json Normal file

File diff suppressed because it is too large.

50259
vocab.json Normal file

File diff suppressed because it is too large.