Add CI test for Whisper models (#239)
This commit is contained in:
@@ -18,15 +18,30 @@ import argparse
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"--encoder",
|
||||
type=str,
|
||||
required=True,
|
||||
# fmt: off
|
||||
choices=[
|
||||
"tiny", "tiny.en", "base", "base.en",
|
||||
"small", "small.en", "medium", "medium.en",
|
||||
"large", "large-v1", "large-v2"],
|
||||
# fmt: on
|
||||
help="Path to the encoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the tokens",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="Path to the test wave",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -161,11 +176,10 @@ def load_tokens(filename):
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
name = args.model
|
||||
encoder = args.encoder
|
||||
decoder = args.decoder
|
||||
|
||||
encoder = f"./{name}-encoder.onnx"
|
||||
decoder = f"./{name}-decoder.onnx"
|
||||
audio = whisper.load_audio("0.wav")
|
||||
audio = whisper.load_audio(args.sound_file)
|
||||
|
||||
features = []
|
||||
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
||||
@@ -224,17 +238,13 @@ def main():
|
||||
logits = logits[0, -1]
|
||||
model.suppress_tokens(logits, is_initial=False)
|
||||
max_token_id = logits.argmax(dim=-1)
|
||||
token_table = load_tokens(f"./{name}-tokens.txt")
|
||||
token_table = load_tokens(args.tokens)
|
||||
s = b""
|
||||
for i in results:
|
||||
if i in token_table:
|
||||
s += base64.b64decode(token_table[i])
|
||||
else:
|
||||
print("oov", i)
|
||||
|
||||
print(s.decode().strip())
|
||||
print(results)
|
||||
print(model.sot_sequence)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user