Support distil-small.en whisper (#472)
This commit is contained in:
@@ -44,7 +44,7 @@ def get_args():
|
||||
"tiny", "tiny.en", "base", "base.en",
|
||||
"small", "small.en", "medium", "medium.en",
|
||||
"large", "large-v1", "large-v2",
|
||||
"distil-medium.en",
|
||||
"distil-medium.en", "distil-small.en", "distil-large-v2"
|
||||
],
|
||||
# fmt: on
|
||||
)
|
||||
@@ -314,6 +314,32 @@ def main():
|
||||
"""
|
||||
)
|
||||
model = whisper.load_model(filename)
|
||||
elif name == "distil-large-v2":
|
||||
filename = "./distil-large-v2-original-model.bin"
|
||||
if not Path(filename).is_file():
|
||||
raise ValueError(
|
||||
"""
|
||||
Please go to https://huggingface.co/distil-whisper/distil-large-v2
|
||||
to download original-model.bin
|
||||
You can use the following command to do that:
|
||||
|
||||
wget -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
|
||||
"""
|
||||
)
|
||||
model = whisper.load_model(filename)
|
||||
elif name == "distil-small.en":
|
||||
filename = "./distil-small-en-original-model.bin"
|
||||
if not Path(filename).is_file():
|
||||
raise ValueError(
|
||||
"""
|
||||
Please go to https://huggingface.co/distil-whisper/distil-small.en
|
||||
to download original-model.bin
|
||||
You can use the following command to do that:
|
||||
|
||||
wget -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
|
||||
"""
|
||||
)
|
||||
model = whisper.load_model(filename)
|
||||
else:
|
||||
model = whisper.load_model(name)
|
||||
print(model.dims)
|
||||
|
||||
@@ -209,7 +209,7 @@ class OnnxModel:
|
||||
logits = logits.reshape(-1)
|
||||
mask = torch.ones(logits.shape[0], dtype=torch.int64)
|
||||
mask[self.all_language_tokens] = 0
|
||||
logits[mask] = float("-inf")
|
||||
logits[mask != 0] = float("-inf")
|
||||
lang_id = logits.argmax().item()
|
||||
print("detected language: ", self.id2lang[lang_id])
|
||||
return lang_id
|
||||
@@ -263,7 +263,9 @@ def compute_features(filename: str) -> torch.Tensor:
|
||||
|
||||
target = 3000
|
||||
if mel.shape[0] > target:
|
||||
mel = mel[:target]
|
||||
# Keep only the first (target - 50) frames so the 50 frames appended below form a zero-padded tail.
|
||||
mel = mel[: target - 50]
|
||||
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
|
||||
|
||||
# We don't need to pad it to 30 seconds now!
|
||||
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||
|
||||
Reference in New Issue
Block a user