Add CI test for Whisper models (#239)

This commit is contained in:
Fangjun Kuang
2023-08-07 19:24:52 +08:00
committed by GitHub
parent 45b9d4ab37
commit f7c05b1570
8 changed files with 155 additions and 16 deletions

View File

@@ -18,15 +18,30 @@ import argparse
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
"--encoder",
type=str,
required=True,
# fmt: off
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2"],
# fmt: on
help="Path to the encoder",
)
parser.add_argument(
"--decoder",
type=str,
required=True,
help="Path to the decoder",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to the tokens",
)
parser.add_argument(
"sound_file",
type=str,
help="Path to the test wave",
)
return parser.parse_args()
@@ -161,11 +176,10 @@ def load_tokens(filename):
def main():
args = get_args()
name = args.model
encoder = args.encoder
decoder = args.decoder
encoder = f"./{name}-encoder.onnx"
decoder = f"./{name}-decoder.onnx"
audio = whisper.load_audio("0.wav")
audio = whisper.load_audio(args.sound_file)
features = []
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
@@ -224,17 +238,13 @@ def main():
logits = logits[0, -1]
model.suppress_tokens(logits, is_initial=False)
max_token_id = logits.argmax(dim=-1)
token_table = load_tokens(f"./{name}-tokens.txt")
token_table = load_tokens(args.tokens)
s = b""
for i in results:
if i in token_table:
s += base64.b64decode(token_table[i])
else:
print("oov", i)
print(s.decode().strip())
print(results)
print(model.sot_sequence)
if __name__ == "__main__":