#!/usr/bin/env python3
"""Render ``*.in`` Jinja2 templates into build scripts for one CI runner.

The full model list is split across ``--total`` runners; the runner selected
by ``--index`` renders each template with only its own share of the models.
"""

import argparse
from dataclasses import dataclass

import jinja2


def get_args() -> argparse.Namespace:
    """Parse the command-line options.

    Returns:
      A namespace with two integer attributes: ``total`` (number of runners,
      default 1) and ``index`` (index of the current runner, default 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--total",
        type=int,
        default=1,
        help="Number of runners",
    )
    parser.add_argument(
        "--index",
        type=int,
        default=0,
        help="Index of the current runner",
    )
    return parser.parse_args()


@dataclass
class SlidModel:
    """One model entry exposed to the templates.

    NOTE(review): the field meanings below are inferred from their names and
    from the values used in get_models(); confirm against the ``*.in``
    templates that consume them.
    """

    # Name of the pre-trained model, e.g. "sherpa-onnx-whisper-tiny".
    model_name: str
    # Integer id of the model; get_models() starts at 0.
    idx: int
    # Short identifier of the model, e.g. "whisper_tiny".
    short_name: str = ""


def get_models():
    """Return the list of all supported models.

    Returns:
      A list of SlidModel instances.
    """
    # see https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#pre-trained-models
    whisper_models = [
        SlidModel(
            model_name="sherpa-onnx-whisper-tiny",
            idx=0,
            short_name="whisper_tiny",
        ),
    ]

    return whisper_models


def main() -> None:
    """Render every template with the models assigned to this runner.

    The models are split into ``total`` contiguous, equally sized chunks.
    Any models left over after the even split are handed out one each to the
    runners with the lowest indexes.

    For every entry ``name`` in the file list, ``name.in`` is read, rendered
    with ``model_list`` set to this runner's models, and written to ``name``.

    Raises:
      ValueError: If ``index`` is not in ``[0, total)`` or if there are fewer
        models than runners.
    """
    args = get_args()
    index = args.index
    total = args.total

    # Validate explicitly instead of with ``assert``: assertions are removed
    # under ``python -O``, and ``total == 0`` would otherwise surface later
    # as a ZeroDivisionError.
    if not 0 <= index < total:
        raise ValueError(
            f"Expected 0 <= index < total, got index={index}, total={total}"
        )

    all_model_list = get_models()

    num_models = len(all_model_list)

    num_per_runner = num_models // total
    if num_per_runner <= 0:
        raise ValueError(f"num_models: {num_models}, num_runners: {total}")

    start = index * num_per_runner
    end = start + num_per_runner

    # Number of models that do not fit into the even split.
    remaining = num_models - total * num_per_runner

    print(f"{index}/{total}: {start}-{end}/{num_models}")

    # Slicing copies, so appending below does not modify all_model_list.
    model_list = all_model_list[start:end]
    if index < remaining:
        # The leftover models sit at the tail of the list; runner ``index``
        # takes the index-th one of them.
        extra = total * num_per_runner + index
        model_list.append(all_model_list[extra])
        print(f"{extra}/{num_models}")

    d = {"model_list": model_list}

    filename_list = [
        "./build-apk-slid.sh",
    ]

    # One environment is enough for all templates.
    environment = jinja2.Environment()
    for filename in filename_list:
        with open(f"{filename}.in", encoding="utf-8") as f:
            template_text = f.read()

        rendered = environment.from_string(template_text).render(**d)

        with open(filename, "w", encoding="utf-8") as f:
            # print() appends a trailing newline after the rendered text.
            print(rendered, file=f)


if __name__ == "__main__":
    main()