Support distil-small.en whisper (#472)

This commit is contained in:
Fangjun Kuang
2023-12-08 11:59:20 +08:00
committed by GitHub
parent 3ae984f148
commit 868c339e5e
7 changed files with 84 additions and 24 deletions

View File

@@ -44,7 +44,7 @@ def get_args():
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2",
"distil-medium.en",
"distil-medium.en", "distil-small.en", "distil-large-v2"
],
# fmt: on
)
@@ -314,6 +314,32 @@ def main():
"""
)
model = whisper.load_model(filename)
elif name == "distil-large-v2":
filename = "./distil-large-v2-original-model.bin"
if not Path(filename).is_file():
raise ValueError(
"""
Please go to https://huggingface.co/distil-whisper/distil-large-v2
to download original-model.bin
You can use the following command to do that:
wget -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
"""
)
model = whisper.load_model(filename)
elif name == "distil-small.en":
filename = "./distil-small-en-original-model.bin"
if not Path(filename).is_file():
raise ValueError(
"""
Please go to https://huggingface.co/distil-whisper/distil-small.en
to download original-model.bin
You can use the following command to do that:
wget -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
"""
)
model = whisper.load_model(filename)
else:
model = whisper.load_model(name)
print(model.dims)

View File

@@ -209,7 +209,7 @@ class OnnxModel:
logits = logits.reshape(-1)
mask = torch.ones(logits.shape[0], dtype=torch.int64)
mask[self.all_language_tokens] = 0
logits[mask] = float("-inf")
logits[mask != 0] = float("-inf")
lang_id = logits.argmax().item()
print("detected language: ", self.id2lang[lang_id])
return lang_id
@@ -263,7 +263,9 @@ def compute_features(filename: str) -> torch.Tensor:
target = 3000
if mel.shape[0] > target:
mel = mel[:target]
# -50 so that there are some zero tail paddings.
mel = mel[: target - 50]
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
# We don't need to pad it to 30 seconds now!
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)