Set batch size to 1 for more streaming ASR models (#1280)

This commit is contained in:
Fangjun Kuang
2024-08-23 11:06:55 +08:00
committed by GitHub
parent c61423ec5a
commit fb09f8fae3
15 changed files with 782 additions and 38 deletions

View File

@@ -2,7 +2,7 @@
import argparse
from dataclasses import dataclass
from typing import List, Optional
from typing import List
import jinja2
@@ -34,76 +34,99 @@ class SpeakerIdentificationModel:
def get_3dspeaker_models() -> List[SpeakerIdentificationModel]:
models = [
SpeakerIdentificationModel(model_name="3dspeaker_speech_campplus_sv_en_voxceleb_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_sv_en_voxceleb_16k.onnx"),
SpeakerIdentificationModel(model_name="3dspeaker_speech_eres2net_sv_zh-cn_16k-common.onnx"),
SpeakerIdentificationModel(
model_name="3dspeaker_speech_campplus_sv_en_voxceleb_16k.onnx"
),
SpeakerIdentificationModel(
model_name="3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx"
),
SpeakerIdentificationModel(
model_name="3dspeaker_speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx"
),
SpeakerIdentificationModel(
model_name="3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
),
SpeakerIdentificationModel(
model_name="3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx"
),
SpeakerIdentificationModel(
model_name="3dspeaker_speech_eres2net_sv_en_voxceleb_16k.onnx"
),
SpeakerIdentificationModel(
model_name="3dspeaker_speech_eres2net_sv_zh-cn_16k-common.onnx"
),
]
prefix = '3dspeaker_speech_'
prefix = "3dspeaker_speech_"
num = len(prefix)
for m in models:
m.framework = '3dspeaker'
m.framework = "3dspeaker"
m.short_name = m.model_name[num:-5]
if '_zh-cn_' in m.model_name:
m.lang = 'zh'
elif '_en_' in m.model_name:
m.lang = 'en'
if "_zh-cn_" in m.model_name:
m.lang = "zh"
elif "_en_" in m.model_name:
m.lang = "en"
else:
raise ValueError(m)
return models
def get_wespeaker_models() -> List[SpeakerIdentificationModel]:
models = [
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_CAM++.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_CAM++_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet152_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet221_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet293_LM.onnx"),
SpeakerIdentificationModel(
model_name="wespeaker_en_voxceleb_resnet152_LM.onnx"
),
SpeakerIdentificationModel(
model_name="wespeaker_en_voxceleb_resnet221_LM.onnx"
),
SpeakerIdentificationModel(
model_name="wespeaker_en_voxceleb_resnet293_LM.onnx"
),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet34.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_en_voxceleb_resnet34_LM.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_zh_cnceleb_resnet34.onnx"),
SpeakerIdentificationModel(model_name="wespeaker_zh_cnceleb_resnet34_LM.onnx"),
]
prefix = 'wespeaker_xx_'
prefix = "wespeaker_xx_"
num = len(prefix)
for m in models:
m.framework = 'wespeaker'
m.framework = "wespeaker"
m.short_name = m.model_name[num:-5]
if '_zh_' in m.model_name:
m.lang = 'zh'
elif '_en_' in m.model_name:
m.lang = 'en'
if "_zh_" in m.model_name:
m.lang = "zh"
elif "_en_" in m.model_name:
m.lang = "en"
else:
raise ValueError(m)
return models
def get_nemo_models() -> List[SpeakerIdentificationModel]:
models = [
SpeakerIdentificationModel(model_name="nemo_en_speakerverification_speakernet.onnx"),
SpeakerIdentificationModel(
model_name="nemo_en_speakerverification_speakernet.onnx"
),
SpeakerIdentificationModel(model_name="nemo_en_titanet_large.onnx"),
SpeakerIdentificationModel(model_name="nemo_en_titanet_small.onnx"),
]
prefix = 'nemo_en_'
prefix = "nemo_en_"
num = len(prefix)
for m in models:
m.framework = 'nemo'
m.framework = "nemo"
m.short_name = m.model_name[num:-5]
if '_zh_' in m.model_name:
m.lang = 'zh'
elif '_en_' in m.model_name:
m.lang = 'en'
if "_zh_" in m.model_name:
m.lang = "zh"
elif "_en_" in m.model_name:
m.lang = "en"
else:
raise ValueError(m)
return models
def main():
args = get_args()
index = args.index