First Commit
This commit is contained in:
14
Dockerfile
Normal file
14
Dockerfile
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Base image: Iluvatar CoreX (BI100) Ubuntu 20.04 / Python 3.10 LLM-inference image.
FROM zibo.harbor.iluvatar.com.cn:30000/saas/bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.2

# CoreX builds of torch/torchaudio, pyannote.audio (speaker embeddings) and av (audio decoding).
RUN pip install --no-cache-dir torch==2.1.0+corex.3.2.1 torchaudio==2.1.0+corex.3.2.1 pyannote.audio av

# Pretrained WeSpeaker ResNet34-LM embedding model, baked into the image at /model.
COPY ./pyannote_models/pyannote-wespeaker-voxceleb-resnet34-LM /model

# Application sources.
COPY ./src/filesystem_storage.py /workspace/filesystem_storage.py
COPY ./src/speaker_identification.py /workspace/speaker_identification.py
COPY ./src/iflytek_interface_server.py /workspace/iflytek_interface_server.py
COPY ./launch_service /workspace/launch_service

WORKDIR /workspace/

ENTRYPOINT ["./launch_service"]
|
||||||
9
launch_service
Executable file
9
launch_service
Executable file
@@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
# Container entrypoint: log host/GPU diagnostics, then start the HTTP server.

date
# Fixed useless use of cat: tail reads the file directly.
tail -n 50 /proc/cpuinfo
# Iluvatar GPU status (vendor equivalent of nvidia-smi).
ixsmi
# Dump environment variables for debugging.
export
date

# exec replaces the shell so the Python server receives container signals
# (e.g. SIGTERM on `docker stop`) directly.
exec python3 iflytek_interface_server.py
|
||||||
35
pyannote_models/pyannote-wespeaker-voxceleb-resnet34-LM/.gitattributes
vendored
Normal file
35
pyannote_models/pyannote-wespeaker-voxceleb-resnet34-LM/.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
BIN
src/.iflytek_interface_server.py.swp
Normal file
BIN
src/.iflytek_interface_server.py.swp
Normal file
Binary file not shown.
56
src/filesystem_storage.py
Normal file
56
src/filesystem_storage.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import os
import shutil
import logging

logger = logging.getLogger(__name__)

# Default root directory; one sub-directory per group is created beneath it.
BASE_DIRECTORY = './embedding_data'


class FileSystemStorage:
    """Minimal key/value store backed by the local filesystem.

    Layout: <base_dir>/<group_id>/<item_id>, one file per item.  The
    ``auth_token`` parameter is accepted for interface compatibility with
    other storage backends but is not used here.
    """

    def __init__(self, base_directory=BASE_DIRECTORY):
        self.base_dir = base_directory

    def create_group(self, auth_token, group_id):
        """Create the directory for a new group.

        Raises:
            FileExistsError: the group already exists.
            OSError: the directory could not be created.
        """
        path = os.path.join(self.base_dir, group_id)
        if os.path.exists(path):
            raise FileExistsError(f"{group_id}")
        try:
            os.makedirs(path)
        except OSError as e:
            logger.error(f"Error creating directory {path}: {e}")
            raise

    def get(self, auth_token, group_id, item_id):
        """Return the raw bytes stored for item_id within group_id."""
        # Build the path before the try block so the except handler can
        # always reference it (in the original, a failing join would have
        # left `path` unbound inside the handler).
        path = os.path.join(self.base_dir, group_id, item_id)
        try:
            with open(path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Error reading from {path}: {e}")
            raise

    def save(self, auth_token, group_id, item_id, content):
        """Write `content` (bytes) for item_id, overwriting any existing item."""
        path = os.path.join(self.base_dir, group_id, item_id)
        try:
            with open(path, 'wb') as f:
                f.write(content)
        except Exception as e:
            logger.error(f"Error saving to {path}: {e}")
            raise

    def remove(self, auth_token, group_id, item_id):
        """Delete a single item from a group.

        Bug fix: the original signature read ``slef`` instead of ``self``,
        so every call raised NameError on ``self.base_dir``.
        """
        path = os.path.join(self.base_dir, group_id, item_id)
        try:
            os.remove(path)
        except Exception as e:
            logger.error(f"Error remove item {path}: {e}")
            raise

    def remove_group(self, auth_token, group_id):
        """Delete a group directory and everything inside it."""
        path = os.path.join(self.base_dir, group_id)
        try:
            shutil.rmtree(path)
        except Exception as e:
            logger.error(f"Error remove group {path}: {e}")
            raise
|
||||||
212
src/iflytek_interface_server.py
Normal file
212
src/iflytek_interface_server.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import av
|
||||||
|
import io
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import torch
|
||||||
|
import numpy
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from flask import Flask, request, jsonify
|
||||||
|
|
||||||
|
format_str = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
||||||
|
datefmt= '%Y-%m-%d %H:%M:%S'
|
||||||
|
logging.basicConfig(level=logging.WARNING, format=format_str, datefmt=datefmt)
|
||||||
|
|
||||||
|
from speaker_identification import init_embedding_model, create_group, enroll_speaker, identify_speaker, calc_similarity, list_speakers, remove_speaker, remove_group
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
init_embedding_model()
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
def samples_from_raw_bytes(data):
    """Convert raw 16-bit signed PCM bytes into float32 samples in [-1.0, 1.0)."""
    pcm = numpy.frombuffer(data, dtype=numpy.int16)
    return pcm.astype(numpy.float32) / 32768.0
|
||||||
|
|
||||||
|
def samples_from_lame_bytes(data):
    """Decode compressed audio bytes (MP3/LAME) to mono float32 samples at 16 kHz.

    Uses PyAV to demux/decode the first audio stream and resample it to
    fltp/mono/16000; returns a 1-D numpy array (empty if nothing decoded).
    """
    input_buffer = io.BytesIO(data)
    with av.open(input_buffer, mode='r') as container:
        # First audio stream in the container; raises StopIteration if none.
        stream = next(s for s in container.streams if s.type == 'audio')
        resampler = av.AudioResampler(
            format='fltp',
            layout='mono',
            rate=16000
        )
        frame_chunks = []
        for packet in container.demux(stream):
            for frame in packet.decode():
                for resampled_frame in resampler.resample(frame):
                    frame_chunks.append(resampled_frame.to_ndarray().flatten())
        # resample(None) flushes any samples buffered inside the resampler.
        for resampled_frame in resampler.resample(None):
            frame_chunks.append(resampled_frame.to_ndarray().flatten())
        if not frame_chunks:
            return numpy.array([], dtype=numpy.float32)
        return numpy.concatenate(frame_chunks)
|
||||||
|
|
||||||
|
def waveform_input_from_b64_audio(audio_b64, audio_format):
    """Decode base64 audio into a pyannote-style input dict.

    'lame' payloads are decoded and resampled via PyAV; any other format
    is treated as raw 16-bit PCM.  Returns {'waveform': (1, n) tensor,
    'sample_rate': 16000}.
    """
    raw = base64.b64decode(audio_b64)
    decoder = samples_from_lame_bytes if audio_format == 'lame' else samples_from_raw_bytes
    samples = decoder(raw)
    waveform = torch.from_numpy(numpy.expand_dims(samples, axis=0))
    return {'waveform': waveform, 'sample_rate': 16000}
|
||||||
|
|
||||||
|
def conv_group_id(aid, gid):
    """Build the storage-level group id by joining the app id and group id."""
    return f'{aid}_____{gid}'
|
||||||
|
|
||||||
|
def process_create_group(req_json):
    """Handle createGroup: create a per-app group and echo its metadata."""
    gid = req_json['parameter']['s782b4996']['groupId']
    aid = req_json['header']['app_id']
    create_group(conv_group_id(aid, gid))
    return {
        'groupName': f'{gid} (groupId)',
        'groupId': gid,
        'groupInfo': f'{gid} (groupId)'
    }
|
||||||
|
|
||||||
|
def process_create_feature(req_json):
    """Handle createFeature: enroll one speaker's audio under a group."""
    params = req_json['parameter']['s782b4996']
    resource = req_json['payload']['resource']
    group_id = conv_group_id(req_json['header']['app_id'], params['groupId'])
    speaker_id = params['featureId']
    audio = waveform_input_from_b64_audio(resource['audio'], resource['encoding'])
    enroll_speaker(group_id, speaker_id, audio)
    return {'featureId': speaker_id}
|
||||||
|
|
||||||
|
def process_feature_list(req_json):
    """Handle queryFeatureList: return id/info pairs for every enrolled speaker.

    Bug fix: the comprehension's loop variable is ``sid``, but the original
    referenced an undefined name ``speaker_id`` in ``featureInfo``, which
    raised NameError for any non-empty group.
    """
    params = req_json['parameter']['s782b4996']
    aid = req_json['header']['app_id']
    gid = params['groupId']
    group_id = conv_group_id(aid, gid)
    speaker_ids = list_speakers(group_id)
    resp = [{'featureInfo': f'{sid} (featureId)',
             'featureId': sid}
            for sid in speaker_ids]
    return resp
|
||||||
|
|
||||||
|
def process_score_feature(req_json):
    """Handle searchScoreFea: score the query audio against one enrolled speaker."""
    params = req_json['parameter']['s782b4996']
    resource = req_json['payload']['resource']
    group_id = conv_group_id(req_json['header']['app_id'], params['groupId'])
    target_id = params['dstFeatureId']
    audio = waveform_input_from_b64_audio(resource['audio'], resource['encoding'])
    score = calc_similarity(audio, group_id, target_id)
    return {
        'score': score,
        'featureInfo': f'{target_id} (featureId)',
        'featureId': target_id,
    }
|
||||||
|
|
||||||
|
def process_search_feature(req_json):
    """Handle searchFea: rank enrolled speakers against the query audio."""
    params = req_json['parameter']['s782b4996']
    resource = req_json['payload']['resource']
    group_id = conv_group_id(req_json['header']['app_id'], params['groupId'])
    audio = waveform_input_from_b64_audio(resource['audio'], resource['encoding'])
    matches = identify_speaker(audio, group_id, params['topK'])
    score_list = [
        {
            'score': score,
            'featureInfo': f'{speaker_id} (featureId)',
            'featureId': speaker_id
        }
        for score, speaker_id in matches
    ]
    return {'scoreList': score_list}
|
||||||
|
|
||||||
|
def process_delete_feature(req_json):
    """Handle deleteFeature: remove a single enrolled speaker from a group."""
    params = req_json['parameter']['s782b4996']
    group_id = conv_group_id(req_json['header']['app_id'], params['groupId'])
    remove_speaker(group_id, params['featureId'])
    return {"msg": "success"}
|
||||||
|
|
||||||
|
def process_delete_group(req_json):
    """Handle deleteGroup: remove a whole group and its enrollments.

    Bug fix: the original called ``remove_group(group_id, speaker_id)`` —
    ``speaker_id`` is never defined in this function (NameError), and
    ``remove_group`` takes only the group id.
    """
    params = req_json['parameter']['s782b4996']
    aid = req_json['header']['app_id']
    gid = params['groupId']
    group_id = conv_group_id(aid, gid)
    remove_group(group_id)
    resp = {"msg": "success"}
    return resp
|
||||||
|
|
||||||
|
def generate_interface_response(success, resp, req_id):
    """Wrap a handler result in the iFlytek response envelope.

    The business payload is JSON-encoded then base64-wrapped under
    payload.updateFeatureRes.text; header.code is 0 on success and the
    generic error code 10009 otherwise.
    """
    encoded = json.dumps(resp).encode('utf-8')
    resp_b64 = base64.b64encode(encoded).decode('utf-8')
    header = {
        "code": 0 if success else 10009,
        "message": "success",
        "sid": req_id,
    }
    payload = {
        "updateFeatureRes": {
            "status": "3",
            "text": resp_b64
        }
    }
    return {"header": header, "payload": payload}
|
||||||
|
|
||||||
|
@app.route('/v1/private/s782b4996', methods=['POST'])
def s782b4996():
    """Single POST endpoint: dispatch on parameter.s782b4996.func and wrap the result.

    Any exception is caught, logged, and turned into a failure envelope so
    the endpoint always answers with a well-formed response.
    """
    req_id = str(uuid.uuid4())
    try:
        req_json = request.json
        func = req_json['parameter']['s782b4996']['func']
        logger.info(f'Processing request {func=}, {req_id=}...')
        # Dispatch table replaces the original if/elif chain; an unknown
        # func leaves resp as None, exactly as before.
        handlers = {
            'createGroup': process_create_group,
            'createFeature': process_create_feature,
            'queryFeatureList': process_feature_list,
            'searchScoreFea': process_score_feature,
            'searchFea': process_search_feature,
            'deleteFeature': process_delete_feature,
            'deleteGroup': process_delete_group,
        }
        ts_beg = time.time()
        handler = handlers.get(func)
        resp = handler(req_json) if handler is not None else None
        elapsed = time.time() - ts_beg
        logger.debug(f'{elapsed=:.3f}s Result = {resp}')
        logger.info(f'Request {req_id} completed.')
        ret = generate_interface_response(True, resp, req_id)
    except Exception as e:
        logger.warning(f'Exception {e}', exc_info=True)
        msg = {'error_msg': str(e)}
        ret = generate_interface_response(False, msg, req_id)
    return ret
|
||||||
|
|
||||||
|
@app.route('/health')
@app.route('/health_check')
def health():
    # Liveness probe; Flask serialises the dict into a JSON response body.
    return {'status': 'ok'}
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Serve on all interfaces, port 80 (matches the container ENTRYPOINT usage).
    app.run(host='0.0.0.0', port=80)
|
||||||
172
src/speaker_identification.py
Normal file
172
src/speaker_identification.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import logging
|
||||||
|
import hashlib
|
||||||
|
import numpy as np
|
||||||
|
from pyannote.audio import Model, Inference
|
||||||
|
from filesystem_storage import FileSystemStorage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STOR_AUTH_TOKEN = ''
|
||||||
|
EMB_MODEL_PATH = "/model/pytorch_model.bin"
|
||||||
|
embedding_model = None
|
||||||
|
storage = FileSystemStorage()
|
||||||
|
|
||||||
|
class SpeakerIDException(Exception):
    """Domain error raised by the public speaker-identification API."""

    def get_err_msg(self):
        """Return the human-readable message this exception carries."""
        return str(self)
|
||||||
|
|
||||||
|
def _safe_id(id):
    """Map an arbitrary id to a filesystem-safe, collision-resistant name.

    Characters outside [A-Za-z0-9_] are replaced with '_', the cleaned id
    is capped at 300 characters, and the first 16 hex digits of the
    SHA-256 of the original id are appended so distinct inputs that clean
    to the same string remain distinct.
    """
    raw = str(id)
    # ASCII alphanumerics and underscore pass through; everything else
    # becomes '_'.  (isascii+isalnum == [A-Za-z0-9] exactly.)
    sanitized = ''.join(
        c if (c == '_' or (c.isascii() and c.isalnum())) else '_'
        for c in raw
    )

    digest = hashlib.sha256(raw.encode('utf-8')).hexdigest()

    return f"{sanitized[:300]}_{digest[:16]}"
|
||||||
|
|
||||||
|
def _calc_embedding(audio):
    """Run the global pyannote Inference model on `audio` and return the embedding.

    `audio` is the dict produced by the server ({'waveform', 'sample_rate'}).
    init_embedding_model() must have been called first; otherwise
    `embedding_model` is still None and calling it raises.
    """
    global embedding_model
    try:
        embedding = embedding_model(audio)
    except Exception as e:
        logger.error(f'Could not calc embedding: inference audio error {e}')
        raise
    return embedding
|
||||||
|
|
||||||
|
def _cosine_similarity(u, v):
    """Return the cosine similarity of u and v, clamped to [0.0, 1.0].

    If either vector has zero norm the quotient is undefined; a warning is
    logged and 1.0 is returned (preserving the original behaviour).
    Negative correlations are clamped to 0.0.
    """
    u = np.asarray(u)
    v = np.asarray(v)

    # Hoisted: the original computed each norm twice (once for the guard,
    # once for the quotient).
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)

    if norm_u == 0 or norm_v == 0:
        logger.warning("Warning: One or both vectors are zero vectors. Cosine distance is undefined.")
        return 1.0

    similarity = float(np.dot(u, v) / (norm_u * norm_v))
    # Clamp: negatives are reported as 0.0; the upper bound guards float error.
    similarity = max(0.0, min(similarity, 1.0))
    return similarity
|
||||||
|
|
||||||
|
def _load_group(group_id):
    """Fetch and unpickle the embeddings dict for an (already sanitised) group id."""
    try:
        # NOTE(review): pickle.loads on stored bytes — safe only while the
        # storage directory is not writable by untrusted parties.
        payload = storage.get(STOR_AUTH_TOKEN, group_id, 'embeddings')
        return pickle.loads(payload)
    except Exception as e:
        logger.error(f'Could not load group with id {group_id}, err = {e}')
        raise
|
||||||
|
|
||||||
|
def _save_group(group_id, embeddings):
    """Pickle and persist the embeddings dict for an (already sanitised) group id."""
    global storage
    try:
        storage.save(STOR_AUTH_TOKEN, group_id, 'embeddings', pickle.dumps(embeddings))
    except Exception as e:
        logger.error(f'Could not save group with id {group_id}, err = {e}')
        raise
|
||||||
|
|
||||||
|
def _create_group(group_id):
    """Create a storage group and initialise it with an empty embeddings dict.

    Bug fix: the original used ``group_id_s`` without ever computing it
    (NameError on every call); the sanitised id is now derived explicitly,
    matching the public create_group().
    """
    try:
        group_id_s = _safe_id(group_id)
        storage.create_group(STOR_AUTH_TOKEN, group_id_s)
        _save_group(group_id_s, {})
    except Exception as e:
        logger.error(f'Could not create group with id {group_id}, err = {e}')
        raise
|
||||||
|
|
||||||
|
# Public Functions
|
||||||
|
def init_embedding_model():
    """Load the WeSpeaker embedding model and move it to the GPU.

    Must be called once at startup; populates the module-level
    `embedding_model` used by _calc_embedding().
    """
    global embedding_model
    model = Model.from_pretrained(EMB_MODEL_PATH)
    # window="whole": one embedding for the entire utterance, not sliding windows.
    embedding_model = Inference(model, window="whole")
    embedding_model.to(torch.device('cuda'))
|
||||||
|
|
||||||
|
def create_group(group_id):
    """Create an empty speaker group; raises SpeakerIDException on any failure."""
    global storage
    try:
        safe_gid = _safe_id(group_id)
        storage.create_group(STOR_AUTH_TOKEN, safe_gid)
        _save_group(safe_gid, {})
    except Exception as e:
        logger.error(f'Could not create group with id {group_id}, err = {e}')
        raise SpeakerIDException(f'Create Group failed')
|
||||||
|
|
||||||
|
def enroll_speaker(group_id, speaker_id, audio):
    """Compute an embedding for `audio` and store it under (group, speaker).

    The original (unsanitised) speaker_id is stored alongside the embedding
    so listings can return the caller-visible id.  Raises
    SpeakerIDException on any failure.
    """
    try:
        group_id_s = _safe_id(group_id)
        speaker_id_s = _safe_id(speaker_id)
        embeddings = _load_group(group_id_s)
        speaker_emb = _calc_embedding(audio)
        embeddings[speaker_id_s] = (speaker_id, speaker_emb)
        _save_group(group_id_s, embeddings)
    except Exception as e:
        # Fixed log wording: original read "Could enroll speaker" on failure.
        logger.error(f'Could not enroll speaker with {group_id=} {speaker_id=}, err = {e}')
        raise SpeakerIDException(f'Enroll Speaker failed')
|
||||||
|
|
||||||
|
def list_speakers(group_id):
    """Return the caller-visible speaker ids enrolled in `group_id`."""
    try:
        embeddings = _load_group(_safe_id(group_id))
        # Each value is (original_speaker_id, embedding); expose the id.
        return [entry[0] for entry in embeddings.values()]
    except Exception as e:
        logger.error(f'Could not list speakers. {group_id=}, err = {e}')
        raise SpeakerIDException(f'List Speakers failed')
|
||||||
|
|
||||||
|
def remove_speaker(group_id, speaker_id):
    """Delete one speaker's enrollment; raises SpeakerIDException on failure."""
    try:
        safe_gid = _safe_id(group_id)
        safe_sid = _safe_id(speaker_id)
        embeddings = _load_group(safe_gid)
        # KeyError (speaker not enrolled) is folded into the generic failure.
        embeddings.pop(safe_sid)
        _save_group(safe_gid, embeddings)
    except Exception as e:
        logger.error(f'Could not remove speaker. {group_id=} {speaker_id = }, err = {e}')
        raise SpeakerIDException(f'Remove Speaker failed')
|
||||||
|
|
||||||
|
def remove_group(group_id):
    """Delete a whole group and its stored embeddings."""
    try:
        storage.remove_group(STOR_AUTH_TOKEN, _safe_id(group_id))
    except Exception as e:
        logger.error(f'Could not remove group {group_id}, err = {e}')
        raise SpeakerIDException(f'Remove Group failed')
|
||||||
|
|
||||||
|
def calc_similarity(audio, group_id, speaker_id):
    """Cosine similarity between `audio` and one enrolled speaker's embedding."""
    try:
        safe_gid = _safe_id(group_id)
        safe_sid = _safe_id(speaker_id)
        embeddings = _load_group(safe_gid)
        # Stored entries are (original_speaker_id, embedding).
        enrolled_emb = embeddings[safe_sid][1]
        query_emb = _calc_embedding(audio)
        return _cosine_similarity(enrolled_emb, query_emb)
    except Exception as e:
        logger.error(f'Could not calculate similarity. {group_id=} {speaker_id=}: {e}')
        raise SpeakerIDException(f'Calculate Similarity failed')
|
||||||
|
|
||||||
|
def identify_speaker(audio, group_id, top_k):
    """Rank enrolled speakers by similarity to `audio`.

    Returns up to `top_k` (similarity, speaker_id) tuples, best first.
    """
    try:
        embeddings = _load_group(_safe_id(group_id))
        query_emb = _calc_embedding(audio)
        scored = [
            (_cosine_similarity(query_emb, emb), orig_id)
            for orig_id, emb in embeddings.values()
        ]
        scored.sort(reverse=True)
        return scored[:top_k]
    except Exception as e:
        logger.error(f'Could not identify speaker. {group_id=}: {e}')
        raise SpeakerIDException(f'Identify Speaker failed')
|
||||||
Reference in New Issue
Block a user