add more models for speaker diarization (#1440)

This commit is contained in:
Fangjun Kuang
2024-10-17 20:03:09 +08:00
committed by GitHub
parent 4783c8f590
commit e0586f1876
3 changed files with 53 additions and 18 deletions

View File

@@ -27,10 +27,22 @@ def get_args():
@dataclass
class SpeakerSegmentationModel:
model_name: str
short_name: str = ""
short_name: str
def get_models() -> List[SpeakerSegmentationModel]:
@dataclass
class SpeakerEmbeddingModel:
    """Describes one speaker embedding model by its names."""

    # Full model name, e.g. "nemo_en_titanet_small".
    model_name: str
    # Short identifier for the model, e.g. "nemo".
    short_name: str
@dataclass
class Model:
    """Pairs one speaker segmentation model with one speaker embedding model."""

    # Segmentation half of the pair.
    segmentation: SpeakerSegmentationModel
    # Embedding half of the pair.
    embedding: SpeakerEmbeddingModel
def get_segmentation_models() -> List[SpeakerSegmentationModel]:
models = [
SpeakerSegmentationModel(
model_name="sherpa-onnx-pyannote-segmentation-3-0",
@@ -45,13 +57,33 @@ def get_models() -> List[SpeakerSegmentationModel]:
return models
def get_embedding_models() -> List[SpeakerEmbeddingModel]:
    """Return the list of supported speaker embedding models.

    Returns:
        A list of ``SpeakerEmbeddingModel`` entries, each carrying the full
        model name and a short name.
    """
    # Entries must be SpeakerEmbeddingModel (not SpeakerSegmentationModel) so
    # that they match the return annotation and the type of ``Model.embedding``.
    return [
        SpeakerEmbeddingModel(
            model_name="3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k",
            short_name="3dspeaker",
        ),
        SpeakerEmbeddingModel(
            model_name="nemo_en_titanet_small",
            short_name="nemo",
        ),
    ]
def main():
args = get_args()
index = args.index
total = args.total
assert 0 <= index < total, (index, total)
all_model_list = get_models()
segmentation_models = get_segmentation_models()
embedding_models = get_embedding_models()
all_model_list = []
for s in segmentation_models:
for e in embedding_models:
all_model_list.append(Model(segmentation=s, embedding=e))
num_models = len(all_model_list)