Files
enginex-bi_series-vc-cnn/utils/client_async.py
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

278 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import bisect
import json
import os
import time
import traceback
from copy import deepcopy
from enum import Enum
from typing import Any, List

import websockets
from pydantic_core import ValidationError

from schemas.context import ASRContext
from schemas.stream import StreamResultModel, StreamWordsModel
from utils.logger import logger
# True when SUBMIT_CONFIG_FILEPATH is absent from the environment
# (presumably meaning "not running under the submission harness").
IN_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
class STATUS_DATA(str, Enum):
    """Handshake states passed between ClientAsync._sender and ClientAsync._recv.

    In this file ClientAsync only uses WAITING_FIRST_INIT, FIRST_FAIL and SUCCESS.
    Members are ``str`` subclasses, so they compare equal to their string values.
    """

    WAITING_FIRST_INIT = "waiting_first_init"
    FIRST_FAIL = "fail"
    WAITING_SECOND_INIT = "waiting_second_init"
    SECOND_INIT = "second_fail"
    WAITING_THIRD_INIT = "waiting_third_init"
    THIRD_INIT = "third_fail"
    SUCCESS = "success"
    CLOSED = "closed"

    # Aliases whose names match their values (and FIRST_FAIL). Because the values
    # repeat, Enum makes these aliases of SECOND_INIT / THIRD_INIT: the old names keep
    # working and iteration order is unchanged.
    SECOND_FAIL = "second_fail"
    THIRD_FAIL = "third_fail"
class ClientAsync:
    """Streaming-ASR test client for a single audio sample.

    Connects to ``<sut_url>/recognition`` over a websocket, sends an init message,
    streams the WAV payload in ``context.chunk_size`` chunks paced by
    ``context.wait_time``, collects the recognition results and finally scores them
    into a private copy of the supplied ``ASRContext``.
    """

    # Sentence punctuation recognised per ``asr_results.language``; any other
    # language (or None) uses the ASCII default.
    # NOTE(review): the el and ti marks are best-effort (Greek ano teleia U+0387 and
    # question mark U+037E; Ethiopic full stop U+1362) - confirm against the spec.
    _DEFAULT_PUNCTUATIONS = (",", ".", "!", "?")
    _PUNCTUATIONS = {
        "zh": ("\uff0c", "\u3002", "\uff01", "\uff1f"),  # full-width comma, 。, ！, ？
        "ja": ("\u3001", "\u3002", "\uff01", "\uff1f"),  # 、 。 ！ ？
        "ar": ("\u060c", ".", "!", "\u061f"),  # Arabic comma and question mark
        "fa": ("\u060c", ".", "!", "\u061f"),
        "el": (",", ".", "\u0387", "\u037e"),
        "ti": ("\u1362",),
    }

    def __init__(self, sut_url: str, context: ASRContext, idx: int) -> None:
        """
        Args:
            sut_url: Base websocket URL of the service under test
                (e.g. ``ws://127.0.0.1:5003``); ``/recognition`` is appended.
            context: Test case; deep-copied so that the results written into it
                do not leak back into the caller's object.
            idx: Sample index; only used in log messages.
        """
        self.base_url = sut_url + "/recognition"
        self.context: ASRContext = deepcopy(context)
        self.idx = idx
        # Failed initialisation attempts; _action stops retrying once it reaches 3.
        self.fail_count = 0
        # Set to perf_counter() + 3 once all audio was sent (far-future sentinel until
        # then). NOTE(review): not read anywhere in this file.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []

    async def _sender(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Send the init message, then stream the audio if ``_recv`` reports success.

        ``send_queue`` carries messages to ``_recv``; ``recv_queue`` carries the init
        outcome back. While streaming, a "char contains rate" check is scheduled 3 s
        (wall clock) after the audio reaching each label's start has been sent. At that
        point the received text length divided by ``context.char_contains_rate`` must
        not be smaller than the total answer length of every label checked so far;
        otherwise ``context.fail_char_contains_rate_num`` is incremented.
        """
        # Raise the transport's write-buffer high-water mark to 1 GiB.
        websocket.transport.set_write_buffer_limits(1024 * 1024 * 1024)
        await websocket.send(json.dumps(self._gen_init_data()))
        await send_queue.put(STATUS_DATA.WAITING_FIRST_INIT)
        connect_status = await recv_queue.get()
        if connect_status == STATUS_DATA.FIRST_FAIL:
            return
        # The PCM payload starts 8 bytes after the "data" tag (tag + 4-byte chunk size).
        with open(self.context.file_path, "rb") as fp:
            meta_length = fp.read().index(b"data") + 8
        try:
            with open(self.context.file_path, "rb") as fp:
                fp.seek(meta_length)  # skip the WAV header
                labels = self.context.labels
                wav_time = 0.0
                label_id = 0
                check_times: List[float] = []
                check_id = 0
                checked_label_text_len = 0  # running sum of answer lengths of checked labels
                while True:
                    now_time = time.perf_counter()
                    chunk = fp.read(int(self.context.chunk_size))
                    if not chunk:
                        break
                    wav_time += self.context.wait_time
                    try:
                        self.send_time.append(time.perf_counter())
                        # Bound the wait so a slow send does not break the real-time pacing.
                        await asyncio.wait_for(websocket.send(chunk), timeout=0.08)
                    except asyncio.exceptions.TimeoutError:
                        pass
                    while label_id < len(labels) and wav_time >= labels[label_id].start:
                        check_times.append(now_time + 3.0)
                        label_id += 1
                    predict_text_len = sum(len(pred.text) for pred in self.predict_data)
                    while check_id < len(check_times) and check_times[check_id] <= now_time:
                        checked_label_text_len += len(labels[check_id].answer)
                        if predict_text_len / self.context.char_contains_rate < checked_label_text_len:
                            self.context.fail_char_contains_rate_num += 1
                        check_id += 1
                    await asyncio.sleep(max(0, self.context.wait_time - (time.perf_counter() - now_time)))
            await websocket.send(json.dumps({"end": True}))
            logger.info(f"{self.idx} 条数据,当条语音数据发送完成")
            logger.info(f"{self.idx} 条数据3s 后关闭双向连接.")
            self.close_time = time.perf_counter() + 3
        except websockets.exceptions.ConnectionClosedError:
            logger.error(f"{self.idx} 条数据发送过程中,连接断开")
        except Exception:
            # format_exc returns the traceback text (print_exc returns None).
            logger.error(traceback.format_exc())
            logger.error(f"{self.idx} 条数据,发送数据失败")

    async def _recv(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Confirm initialisation, then collect recognition results until the socket closes.

        ``recv_queue`` carries messages from ``_sender``; ``send_queue`` carries the init
        outcome (``SUCCESS`` or ``FIRST_FAIL``) back to it. Any first server message
        received within 2 s counts as success; its content is not inspected.
        A failed attempt is counted by ``_action``, not here.

        Exits the whole process (``os._exit(1)``) if a result fails schema validation.
        """
        await recv_queue.get()  # wait until _sender has sent the init message
        try:
            await asyncio.wait_for(websocket.recv(), timeout=2)
        except asyncio.exceptions.TimeoutError:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.info(f"{self.idx} 条数据,初始化阶段, 2s 没收到 success 返回,超时了")
            return
        except Exception as e:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.error(f"{self.idx} 条数据,初始化阶段, 收到异常:{e}")
            return
        await send_queue.put(STATUS_DATA.SUCCESS)

        try:
            while websocket.open:
                recv_data = await websocket.recv()
                if not isinstance(recv_data, str):
                    raise Exception("返回的结果不是字符串形式")
                self.recv_time.append(time.perf_counter())
                result = StreamResultModel(**json.loads(recv_data))
                recognition_results = result.asr_results
                # Drop final results with no language and all-zero timing / para_seq
                # (presumably an empty terminator from the service).
                is_empty_final = (
                    recognition_results.final_result
                    and not recognition_results.language
                    and recognition_results.start_time == 0
                    and recognition_results.end_time == 0
                    and recognition_results.para_seq == 0
                )
                if not is_empty_final:
                    self.predict_data.append(recognition_results)
        except (websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError):
            pass
        except ValidationError as e:
            logger.error(f"{self.idx} 条数据,返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError:
            pass
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(f"{self.idx} 条数据,处理被测服务返回数据时出错")

    async def _action(self):
        """Run one test, reconnecting until audio was sent or 3 attempts have failed."""
        logger.info(f"{self.idx} 条数据开始测试")
        while self.fail_count < 3:
            # Per-attempt results: a failed attempt must not leave data behind.
            self.send_time = []
            self.recv_time = []
            self.predict_data = []
            to_receiver: asyncio.Queue = asyncio.Queue()  # _sender -> _recv
            to_sender: asyncio.Queue = asyncio.Queue()  # _recv -> _sender
            async with websockets.connect(self.base_url) as websocket:
                send_task = asyncio.create_task(self._sender(websocket, to_receiver, to_sender))
                recv_task = asyncio.create_task(self._recv(websocket, to_sender, to_receiver))
                await send_task
                await asyncio.sleep(3)
                # NOTE(review): _recv only returns once the connection is closed (or errors),
                # so this waits for the server to close the socket.
                await recv_task
            if self.send_time:
                break
            self.fail_count += 1
            logger.info(f"{self.idx} 条数据,初始化阶段, 第 {self.fail_count} 次失败, 1s 后重试")
            # Non-blocking back-off; time.sleep would stall the event loop.
            await asyncio.sleep(1)

    def action(self):
        """Run the test synchronously and return the scored ``ASRContext``."""
        asyncio.run(self._action())
        return self._gen_result()

    def _gen_result(self) -> ASRContext:
        """Store the predictions in ``self.context`` and score punctuation placement.

        ``context.punctuation_num`` is set to the number of labels. For each label a
        window is built: from its end to the next label's start, or +/-0.7 s around its
        end for the last label. ``context.pred_punctuation_num`` is incremented when any
        predicted punctuation word starts or ends inside that window (inclusive).
        """
        if not self.predict_data:
            logger.error(f"{self.idx} 条数据,没有任何数据返回")
        self.context.append_preds(self.predict_data, self.send_time, self.recv_time)
        self.context.fail = not self.predict_data

        punctuation_words: List[StreamWordsModel] = []
        for pred in self.predict_data:
            punctuations = self._PUNCTUATIONS.get(pred.language, self._DEFAULT_PUNCTUATIONS)
            punctuation_words.extend(word for word in pred.words if word.text in punctuations)
        start_times = sorted(word.start_time for word in punctuation_words)
        end_times = sorted(word.end_time for word in punctuation_words)

        labels = self.context.labels
        label_n = len(labels)
        self.context.punctuation_num = label_n
        for i, label in enumerate(labels):
            if i < label_n - 1:
                label_left, label_right = label.end, labels[i + 1].start
            else:
                label_left, label_right = label.end - 0.7, label.end + 0.7
            # Earliest punctuation start >= label_left; inside if it is also <= label_right.
            first_start = bisect.bisect_left(start_times, label_left)
            starts_inside = first_start < len(start_times) and start_times[first_start] <= label_right
            # Latest punctuation end <= label_right; inside if it is also >= label_left.
            last_end = bisect.bisect_right(end_times, label_right) - 1
            ends_inside = last_end >= 0 and end_times[last_end] >= label_left
            if starts_inside or ends_inside:
                self.context.pred_punctuation_num += 1
        return self.context

    def _gen_init_data(self) -> dict:
        """Build the init message from the context's audio parameters.

        ``lang`` is sent as None (presumably leaving language selection to the service).
        """
        return {
            "parameter": {
                "lang": None,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }