Support distil-whisper (#411)

This commit is contained in:
Fangjun Kuang
2023-11-06 22:33:39 +08:00
committed by GitHub
parent 86baf43c6b
commit a65cdc3d76
2 changed files with 53 additions and 5 deletions

View File

@@ -39,7 +39,9 @@ def get_args():
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2"],
"large", "large-v1", "large-v2",
"distil-medium.en",
],
# fmt: on
)
return parser.parse_args()
@@ -257,10 +259,27 @@ def convert_tokens(name, model):
def main():
args = get_args()
name = args.model
print(args)
print(name)
opset_version = 13
model = whisper.load_model(name)
if name == "distil-medium.en":
filename = "./distil-medium-en-original-model.bin"
if not Path(filename):
raise ValueError(
"""
Please go to https://huggingface.co/distil-whisper/distil-medium.en
to download original-model.bin
You can use the following command to do that:
wget -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
"""
)
model = whisper.load_model(filename)
else:
model = whisper.load_model(name)
print(
f"number of model parameters: {name}",
sum(p.numel() for p in model.parameters()),