Files
enginex-bi_series-vc-cnn/utils/client_async.py

278 lines
11 KiB
Python
Raw Permalink Normal View History

2025-08-06 15:38:55 +08:00
import asyncio
import bisect
import json
import os
import time
import traceback
from copy import deepcopy
from enum import Enum
from typing import Any, List

import websockets
from pydantic_core import ValidationError

from schemas.context import ASRContext
from schemas.stream import StreamResultModel, StreamWordsModel
from utils.logger import logger
# True when the SUBMIT_CONFIG_FILEPATH environment variable is not set.
# NOTE(review): the name suggests this marks a local/test run; it is not read
# anywhere in the visible part of this module - confirm against other callers.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
class STATUS_DATA(str, Enum):
    """Status tokens exchanged between the sender and receiver coroutines.

    Members are ``str`` subclasses, so they compare equal to their plain
    string values.

    Within this module only ``WAITING_FIRST_INIT``, ``FIRST_FAIL`` and
    ``SUCCESS`` are used; the remaining members are kept for compatibility
    with any code that references them elsewhere.
    """

    WAITING_FIRST_INIT = "waiting_first_init"
    # NOTE: the value is "fail", not "first_fail"; kept as-is for compatibility.
    FIRST_FAIL = "fail"
    WAITING_SECOND_INIT = "waiting_second_init"
    # NOTE(review): SECOND_INIT / THIRD_INIT carry "*_fail" values, so they
    # presumably denote failures despite their names - confirm before renaming.
    SECOND_INIT = "second_fail"
    WAITING_THIRD_INIT = "waiting_third_init"
    THIRD_INIT = "third_fail"
    SUCCESS = "success"
    CLOSED = "closed"
class ClientAsync:
    """Stream one audio file to a websocket recognition service and score it.

    A sender coroutine pushes the audio in fixed-size chunks while a receiver
    coroutine collects the recognition results. After the exchange,
    ``_gen_result`` stores the predictions on the context and counts how many
    label boundaries have a predicted punctuation mark nearby.
    """

    # Punctuation marks looked for when no language-specific entry applies.
    _DEFAULT_PUNCTUATIONS = [",", ".", "!", "?"]

    # Language-specific punctuation marks.
    # NOTE(review): the empty strings below look like non-ASCII punctuation
    # characters that were lost in an encoding step; they are kept verbatim
    # here - restore them from the upstream source if that is the case.
    _PUNCTUATIONS_BY_LANGUAGE = {
        "zh": ["", "", "", ""],
        "ja": ["", "", "", ""],
        "ar": ["،", ".", "!", "؟"],
        "fa": ["،", ".", "!", "؟"],
        "el": [",", ".", "", ""],
        "ti": [""],
    }

    def __init__(self, sut_url: str, context: ASRContext, idx: int) -> None:
        """Prepare a client for one sample.

        Args:
            sut_url: Base websocket URL of the service; "/recognition" is
                appended to it.
            context: Sample description; deep-copied so the caller's object
                is not mutated.
            idx: Index of the sample, used only in log messages.
        """
        self.base_url = sut_url + "/recognition"
        self.context: ASRContext = deepcopy(context)
        self.idx = idx
        # Number of attempts that ended without any audio chunk being sent.
        self.fail_count = 0
        # Effectively "never" until the sender finishes and sets a real value.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []

    async def _sender(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Send the init message, then stream the audio file chunk by chunk.

        Args:
            websocket: Open connection to the service.
            send_queue: Queue used to tell the receiver that the init message
                has been sent.
            recv_queue: Queue on which the receiver reports the handshake
                outcome.
        """
        # Raise the websocket write-buffer limit.
        websocket.transport.set_write_buffer_limits(1024 * 1024 * 1024)

        # Send the initialisation message and wait for the receiver's verdict.
        await websocket.send(json.dumps(self._gen_init_data()))
        await send_queue.put(STATUS_DATA.WAITING_FIRST_INIT)
        connect_status = await recv_queue.get()
        if connect_status == STATUS_DATA.FIRST_FAIL:
            return

        # Locate the end of the wav header: the "data" tag plus 8 bytes.
        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
        meta_length = wav_data.index(b"data") + 8

        labels = self.context.labels
        try:
            with open(self.context.file_path, "rb") as fp:
                # Skip the wav header.
                fp.seek(meta_length)
                wav_time = 0.0
                label_id = 0
                # One deadline per label whose start has been streamed; the
                # deadline is 3 seconds after the chunk that reached it.
                check_deadlines: List[float] = []
                check_id = 0
                while True:
                    now_time = time.perf_counter()
                    chunk = fp.read(int(self.context.chunk_size))
                    if not chunk:
                        break
                    wav_time += self.context.wait_time
                    try:
                        self.send_time.append(time.perf_counter())
                        await asyncio.wait_for(websocket.send(chunk), timeout=0.08)
                    except asyncio.exceptions.TimeoutError:
                        # A slow send is abandoned after 80 ms so pacing is kept.
                        pass

                    while label_id < len(labels) and wav_time >= labels[label_id].start:
                        check_deadlines.append(now_time + 3.0)
                        label_id += 1

                    predict_text_len = sum(len(pred.text) for pred in self.predict_data)
                    while check_id < len(check_deadlines) and check_deadlines[check_id] <= now_time:
                        label_text_len = sum(len(label.answer) for label in labels[: check_id + 1])
                        if predict_text_len / self.context.char_contains_rate < label_text_len:
                            self.context.fail_char_contains_rate_num += 1
                        check_id += 1

                    # Sleep for whatever remains of this chunk's time slot.
                    elapsed = time.perf_counter() - now_time
                    await asyncio.sleep(max(0, self.context.wait_time - elapsed))

            await websocket.send(json.dumps({"end": True}))
            logger.info(f"{self.idx} 条数据,当条语音数据发送完成")
            logger.info(f"{self.idx} 条数据3s 后关闭双向连接.")
            self.close_time = time.perf_counter() + 3
        except websockets.exceptions.ConnectionClosedError:
            logger.error(f"{self.idx} 条数据发送过程中,连接断开")
        except Exception:
            # format_exc() returns the traceback text; print_exc() returns None.
            logger.error(traceback.format_exc())
            logger.error(f"{self.idx} 条数据,发送数据失败")

    async def _recv(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Confirm the handshake, then collect recognition results.

        ``_action`` passes the two queues in the opposite order to the one it
        uses for ``_sender``, so ``send_queue`` here is the queue the sender
        reads from and ``recv_queue`` is the queue the sender writes to.

        Failed handshakes are reported to the sender only; ``_action`` is the
        single place that counts failed attempts.
        """
        # Wait until the sender has sent the init message.
        await recv_queue.get()
        try:
            await asyncio.wait_for(websocket.recv(), timeout=2)
        except asyncio.exceptions.TimeoutError:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.info(f"{self.idx} 条数据,初始化阶段, 2s 没收到 success 返回,超时了")
            return
        except Exception as e:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.error(f"{self.idx} 条数据,初始化阶段, 收到异常:{e}")
            return
        else:
            await send_queue.put(STATUS_DATA.SUCCESS)

        # Collect recognition results until the connection is closed.
        try:
            while websocket.open:
                recv_data = await websocket.recv()
                if not isinstance(recv_data, str):
                    raise Exception("返回的结果不是字符串形式")
                self.recv_time.append(time.perf_counter())
                result = StreamResultModel(**json.loads(str(recv_data)))
                recognition_results = result.asr_results
                # A final result with no language and all-zero timing and
                # sequence fields is skipped.
                is_blank_final = (
                    recognition_results.final_result
                    and not recognition_results.language
                    and recognition_results.start_time == 0
                    and recognition_results.end_time == 0
                    and recognition_results.para_seq == 0
                )
                if not is_blank_final:
                    self.predict_data.append(recognition_results)
        except websockets.exceptions.ConnectionClosedOK:
            pass
        except websockets.exceptions.ConnectionClosedError:
            pass
        except ValidationError as e:
            # A malformed result terminates the whole process immediately.
            logger.error(f"{self.idx} 条数据,返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError:
            pass
        except Exception:
            # format_exc() returns the traceback text; print_exc() returns None.
            logger.error(traceback.format_exc())
            logger.error(f"{self.idx} 条数据,处理被测服务返回数据时出错")

    async def _action(self):
        """Run send/receive attempts until audio was sent or 3 attempts failed."""
        logger.info(f"{self.idx} 条数据开始测试")
        while self.fail_count < 3:
            send_queue = asyncio.Queue()
            recv_queue = asyncio.Queue()
            # Every attempt starts from empty measurements.
            self.send_time: List[float] = []
            self.recv_time: List[float] = []
            self.predict_data: List[Any] = []
            async with websockets.connect(self.base_url) as websocket:
                send_task = asyncio.create_task(self._sender(websocket, send_queue, recv_queue))
                # Queues are swapped on purpose: one side's send is the other's receive.
                recv_task = asyncio.create_task(self._recv(websocket, recv_queue, send_queue))
                await send_task
                # Keep the connection open 3 more seconds for late results.
                await asyncio.sleep(3)
            # Leaving the block above closes the connection, which lets the
            # receive loop finish.
            await recv_task
            if self.send_time:
                break
            self.fail_count += 1
            logger.info(f"{self.idx} 条数据,初始化阶段, 第 {self.fail_count} 次失败, 1s 后重试")
            # Non-blocking pause so the event loop is not stalled.
            await asyncio.sleep(1)

    def action(self) -> ASRContext:
        """Run the whole exchange synchronously and return the scored context."""
        asyncio.run(self._action())
        return self._gen_result()

    @staticmethod
    def _punctuation_in_window(
        start_times: List[float], end_times: List[float], left: float, right: float
    ) -> bool:
        """Return True if a punctuation mark touches the window [left, right].

        Both lists must be sorted ascending. A mark counts when the earliest
        start time >= ``left`` is <= ``right``, or when the latest end time
        <= ``right`` is >= ``left``.
        """
        first = bisect.bisect_left(start_times, left)
        if first < len(start_times) and start_times[first] <= right:
            return True
        last = bisect.bisect_right(end_times, right) - 1
        return last >= 0 and end_times[last] >= left

    def _gen_result(self) -> ASRContext:
        """Store the predictions on the context and count punctuation hits.

        Returns:
            The client's context, with ``fail``, ``punctuation_num`` and
            ``pred_punctuation_num`` updated and the predictions appended.
        """
        if not self.predict_data:
            logger.error(f"{self.idx} 条数据,没有任何数据返回")
        self.context.append_preds(self.predict_data, self.send_time, self.recv_time)
        self.context.fail = not self.predict_data

        # Collect every predicted word that is a punctuation mark for the
        # language of its result.
        punctuation_words: List[StreamWordsModel] = []
        for pred in self.predict_data:
            punctuations = self._PUNCTUATIONS_BY_LANGUAGE.get(pred.language, self._DEFAULT_PUNCTUATIONS)
            punctuation_words.extend(word for word in pred.words if word.text in punctuations)

        start_times = sorted(word.start_time for word in punctuation_words)
        end_times = sorted(word.end_time for word in punctuation_words)

        labels = self.context.labels
        label_n = len(labels)
        self.context.punctuation_num = label_n
        for i, label in enumerate(labels):
            if i < label_n - 1:
                # Between two labels: the gap from this end to the next start.
                label_left = label.end
                label_right = labels[i + 1].start
            else:
                # Last label: 0.7 either side of its end.
                label_left = label.end - 0.7
                label_right = label.end + 0.7
            if self._punctuation_in_window(start_times, end_times, label_left, label_right):
                self.context.pred_punctuation_num += 1
        return self.context

    def _gen_init_data(self) -> dict:
        """Build the initialisation message sent before any audio."""
        return {
            "parameter": {
                "lang": None,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }