Support Giga AM transducer V2 (#2136)
This commit is contained in:
11
scripts/nemo/GigaAM/export-onnx-ctc-v2.py
Normal file → Executable file
11
scripts/nemo/GigaAM/export-onnx-ctc-v2.py
Normal file → Executable file
@@ -1,3 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
import gigaam
|
||||
import onnx
|
||||
import torch
|
||||
@@ -27,7 +28,13 @@ def add_meta_data(filename: str, meta_data: dict[str, str]):
|
||||
|
||||
def main() -> None:
|
||||
model_name = "v2_ctc"
|
||||
model = gigaam.load_model(model_name, fp16_encoder=False, use_flash=False, download_root=".")
|
||||
model = gigaam.load_model(
|
||||
model_name, fp16_encoder=False, use_flash=False, download_root="."
|
||||
)
|
||||
|
||||
# use characters
|
||||
# space is 0
|
||||
# <blk> is the last token
|
||||
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
||||
for i, s in enumerate(model.cfg["labels"]):
|
||||
f.write(f"{s} {i}\n")
|
||||
@@ -53,5 +60,5 @@ def main() -> None:
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user