add more models for speaker diarization (#1440)
This commit is contained in:
@@ -27,10 +27,22 @@ def get_args():
|
||||
@dataclass
|
||||
class SpeakerSegmentationModel:
|
||||
model_name: str
|
||||
short_name: str = ""
|
||||
short_name: str
|
||||
|
||||
|
||||
def get_models() -> List[SpeakerSegmentationModel]:
|
||||
@dataclass
class SpeakerEmbeddingModel:
    """A speaker embedding model, identified by a full and a short name."""

    # Full model name, e.g. "nemo_en_titanet_small".
    model_name: str
    # Abbreviated name for the same model, e.g. "nemo".
    short_name: str
|
||||
|
||||
|
||||
@dataclass
class Model:
    """A speaker segmentation model paired with a speaker embedding model."""

    # Segmentation half of the pair.
    segmentation: SpeakerSegmentationModel
    # Embedding half of the pair.
    embedding: SpeakerEmbeddingModel
|
||||
|
||||
|
||||
def get_segmentation_models() -> List[SpeakerSegmentationModel]:
|
||||
models = [
|
||||
SpeakerSegmentationModel(
|
||||
model_name="sherpa-onnx-pyannote-segmentation-3-0",
|
||||
@@ -45,13 +57,33 @@ def get_models() -> List[SpeakerSegmentationModel]:
|
||||
return models
|
||||
|
||||
|
||||
def get_embedding_models() -> List[SpeakerEmbeddingModel]:
    """Return the speaker embedding models used by this script.

    Returns:
        A list of SpeakerEmbeddingModel entries, each with the full model
        name and a short name for it.
    """
    # Entries must be SpeakerEmbeddingModel (not SpeakerSegmentationModel)
    # so that the list matches the declared return type.
    return [
        SpeakerEmbeddingModel(
            model_name="3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k",
            short_name="3dspeaker",
        ),
        SpeakerEmbeddingModel(
            model_name="nemo_en_titanet_small",
            short_name="nemo",
        ),
    ]
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
index = args.index
|
||||
total = args.total
|
||||
assert 0 <= index < total, (index, total)
|
||||
|
||||
all_model_list = get_models()
|
||||
segmentation_models = get_segmentation_models()
|
||||
embedding_models = get_embedding_models()
|
||||
|
||||
all_model_list = []
|
||||
for s in segmentation_models:
|
||||
for e in embedding_models:
|
||||
all_model_list.append(Model(segmentation=s, embedding=e))
|
||||
|
||||
num_models = len(all_model_list)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user