import asyncio
import bisect
import json
import os
import time
import traceback
from copy import deepcopy
from enum import Enum
from typing import Any, List

import websockets
from pydantic_core import ValidationError

from schemas.context import ASRContext
from schemas.stream import StreamResultModel, StreamWordsModel
from utils.logger import logger

# True when the SUBMIT_CONFIG_FILEPATH environment variable is not set.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None

# Punctuation marks looked for in predicted words, keyed by the language code
# reported on each prediction; any other language uses the default set.
_DEFAULT_PUNCTUATIONS = (",", ".", "!", "?")
_PUNCTUATIONS_BY_LANGUAGE = {
    "zh": (",", "。", "!", "?"),
    "ja": ("、", "。", "!", "?"),
    "ar": ("،", ".", "!", "؟"),
    "fa": ("،", ".", "!", "؟"),
    "el": (",", ".", "!", ";"),
    "ti": ("།",),
}


class STATUS_DATA(str, Enum):
    """Handshake status tokens exchanged between the sender and receiver tasks.

    Within this file only WAITING_FIRST_INIT, FIRST_FAIL and SUCCESS are
    actually put on the queues; the remaining members are kept for
    compatibility with code outside this file.
    """

    WAITING_FIRST_INIT = "waiting_first_init"
    FIRST_FAIL = "fail"
    WAITING_SECOND_INIT = "waiting_second_init"
    SECOND_INIT = "second_fail"
    WAITING_THIRD_INIT = "waiting_third_init"
    THIRD_INIT = "third_fail"
    SUCCESS = "success"
    CLOSED = "closed"


class ClientAsync:
    """Stream one audio sample to a recognition websocket and collect results.

    A sender task pushes the init message and then the audio chunks, while a
    receiver task validates the handshake reply and accumulates recognition
    results. ``action()`` runs both and returns the updated context.
    """

    def __init__(self, sut_url: str, context: ASRContext, idx: int) -> None:
        """Prepare a client for one sample.

        Args:
            sut_url: Base websocket URL; ``/recognition`` is appended to it.
            context: Sample description; a deep copy is stored so the
                caller's object is not mutated.
            idx: Index of the sample, used only in log messages.
        """
        # base_url = "ws://127.0.0.1:5003"
        self.base_url = sut_url + "/recognition"
        self.context: ASRContext = deepcopy(context)
        self.idx = idx
        # if not os.getenv("DATASET_FILEPATH", ""):
        #     self.base_url = "wss://speech.4paradigm.com/aibuds/api/v1/recognition"
        #     self.base_url = "ws://localhost:5003/recognition"
        self.fail_count = 0
        # Written by _sender once streaming finishes; nothing in this file
        # reads it back.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []

    async def _sender(
        self,
        websocket: websockets.WebSocketClientProtocol,
        send_queue: asyncio.Queue,
        recv_queue: asyncio.Queue,
    ):
        """Send the init message, wait for the handshake verdict, stream audio.

        Args:
            websocket: Open connection to the service under test.
            send_queue: Queue on which this task announces that the init
                message has been sent.
            recv_queue: Queue on which the receiver task reports whether the
                handshake succeeded.

        Raises:
            ValueError: If the file contains no ``b"data"`` marker (raised
                before the streaming ``try`` block, so it propagates).
            OSError: If the audio file cannot be opened for the header scan.
        """
        # Raise the websocket write buffer limit.
        websocket.transport.set_write_buffer_limits(1024 * 1024 * 1024)

        # Send the initialisation message and wait for the receiver's verdict.
        await websocket.send(json.dumps(self._gen_init_data()))
        await send_queue.put(STATUS_DATA.WAITING_FIRST_INIT)
        connect_status = await recv_queue.get()
        if connect_status == STATUS_DATA.FIRST_FAIL:
            return

        # Locate the end of the WAV header: the b"data" tag plus the 8 bytes
        # that the original code skips along with it.
        with open(self.context.file_path, "rb") as fp:
            meta_length = fp.read().index(b"data") + 8

        labels = self.context.labels
        try:
            with open(self.context.file_path, "rb") as fp:
                # Skip the WAV header.
                fp.seek(meta_length)
                wav_time = 0.0
                label_id = 0
                # Wall-clock deadlines (one per label already reached in the
                # audio) at which the amount of predicted text is checked.
                check_deadlines: List[float] = []
                check_id = 0
                while True:
                    now_time = time.perf_counter()
                    chunk = fp.read(int(self.context.chunk_size))
                    if not chunk:
                        break
                    wav_time += self.context.wait_time
                    try:
                        self.send_time.append(time.perf_counter())
                        await asyncio.wait_for(websocket.send(chunk), timeout=0.08)
                    except asyncio.exceptions.TimeoutError:
                        # A slow send is tolerated; streaming moves on.
                        pass

                    # Every label whose start has now been streamed gets a
                    # check scheduled 3 seconds after this chunk's start.
                    while label_id < len(labels) and wav_time >= labels[label_id].start:
                        check_deadlines.append(now_time + 3.0)
                        label_id += 1

                    predict_text_len = sum(len(pred.text) for pred in self.predict_data)
                    while check_id < len(check_deadlines) and check_deadlines[check_id] <= now_time:
                        label_text_len = sum(
                            len(label.answer) for label in labels[: check_id + 1]
                        )
                        if predict_text_len / self.context.char_contains_rate < label_text_len:
                            self.context.fail_char_contains_rate_num += 1
                        check_id += 1

                    # Pace the stream: one chunk per wait_time seconds.
                    await asyncio.sleep(
                        max(0, self.context.wait_time - (time.perf_counter() - now_time))
                    )

            await websocket.send(json.dumps({"end": True}))
            logger.info(f"第 {self.idx} 条数据,当条语音数据发送完成")
            logger.info(f"第 {self.idx} 条数据,3s 后关闭双向连接.")
            self.close_time = time.perf_counter() + 3
        except websockets.exceptions.ConnectionClosedError:
            logger.error(f"第 {self.idx} 条数据发送过程中,连接断开")
        except Exception:
            # format_exc() returns the traceback text; print_exc() returns
            # None, which is what used to end up in the log.
            logger.error(traceback.format_exc())
            logger.error(f"第 {self.idx} 条数据,发送数据失败")

    async def _recv(
        self,
        websocket: websockets.WebSocketClientProtocol,
        send_queue: asyncio.Queue,
        recv_queue: asyncio.Queue,
    ):
        """Validate the handshake reply, then collect recognition results.

        Note that ``_action`` passes the two queues in swapped order relative
        to ``_sender``, so ``recv_queue`` here is the queue the sender writes
        to and ``send_queue`` is the queue the sender reads from.

        Failures during the handshake are reported to the sender as
        ``STATUS_DATA.FIRST_FAIL``; counting the failed attempt is left to
        ``_action`` so each attempt is counted exactly once.
        """
        # Wait until the sender has sent the init message.
        await recv_queue.get()
        try:
            await asyncio.wait_for(websocket.recv(), timeout=2)
        except asyncio.exceptions.TimeoutError:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.info(f"第 {self.idx} 条数据,初始化阶段, 2s 没收到 success 返回,超时了")
            return
        except Exception as e:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.error(f"第 {self.idx} 条数据,初始化阶段, 收到异常:{e}")
            return
        else:
            await send_queue.put(STATUS_DATA.SUCCESS)

        # Start receiving recognition results.
        try:
            while websocket.open:
                recv_data = await websocket.recv()
                if not isinstance(recv_data, str):
                    raise Exception("返回的结果不是字符串形式")
                self.recv_time.append(time.perf_counter())
                result = StreamResultModel(**json.loads(recv_data))
                recognition_results = result.asr_results
                # A final result with no language and all-zero timing/sequence
                # fields is skipped rather than stored as a prediction.
                is_placeholder = (
                    recognition_results.final_result
                    and not recognition_results.language
                    and recognition_results.start_time == 0
                    and recognition_results.end_time == 0
                    and recognition_results.para_seq == 0
                )
                if not is_placeholder:
                    self.predict_data.append(recognition_results)
        except websockets.exceptions.ConnectionClosedOK:
            pass
        except websockets.exceptions.ConnectionClosedError:
            pass
        except ValidationError as e:
            logger.error(f"第 {self.idx} 条数据,返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            # A malformed response terminates the whole process immediately.
            os._exit(1)
        except OSError:
            pass
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(f"第 {self.idx} 条数据,处理被测服务返回数据时出错")

    async def _action(self):
        """Run the send/receive exchange, retrying until ``fail_count`` reaches 3.

        An attempt counts as failed when no audio chunk was sent (for example
        because the handshake failed); collected data is reset before every
        attempt.
        """
        logger.info(f"第 {self.idx} 条数据开始测试")
        while self.fail_count < 3:
            send_queue: asyncio.Queue = asyncio.Queue()
            recv_queue: asyncio.Queue = asyncio.Queue()
            self.send_time = []
            self.recv_time = []
            self.predict_data = []
            async with websockets.connect(self.base_url) as websocket:
                send_task = asyncio.create_task(self._sender(websocket, send_queue, recv_queue))
                # The queues are deliberately swapped for the receiver.
                recv_task = asyncio.create_task(self._recv(websocket, recv_queue, send_queue))
                await asyncio.gather(send_task)
                await asyncio.sleep(3)
                # NOTE(review): the receiver only returns once the connection
                # is closed; if the service never closes it, this await does
                # not finish. Confirm whether the socket should be closed
                # here after the 3 s grace period.
                await asyncio.gather(recv_task)
            if self.send_time:
                break
            self.fail_count += 1
            logger.info(f"第 {self.idx} 条数据,初始化阶段, 第 {self.fail_count} 次失败, 1s 后重试")
            # Non-blocking pause so the event loop is not stalled.
            await asyncio.sleep(1)

    def action(self):
        """Run the whole exchange synchronously and return the updated context."""
        asyncio.run(self._action())
        return self._gen_result()

    def _gen_result(self) -> ASRContext:
        """Store the collected predictions on the context and score punctuation.

        For each label a time window is derived (from the label's end to the
        next label's start, or end ± 0.7 for the last label). The label counts
        as matched when some predicted punctuation word starts inside the
        window or ends inside it.

        Returns:
            The context, with predictions appended and ``fail``,
            ``punctuation_num`` and ``pred_punctuation_num`` updated.
        """
        if not self.predict_data:
            logger.error(f"第 {self.idx} 条数据,没有任何数据返回")
        self.context.append_preds(self.predict_data, self.send_time, self.recv_time)
        self.context.fail = not self.predict_data

        punctuation_words: List[StreamWordsModel] = []
        for pred in self.predict_data:
            punctuations = _PUNCTUATIONS_BY_LANGUAGE.get(pred.language, _DEFAULT_PUNCTUATIONS)
            punctuation_words.extend(word for word in pred.words if word.text in punctuations)

        start_times = sorted(word.start_time for word in punctuation_words)
        end_times = sorted(word.end_time for word in punctuation_words)

        labels = self.context.labels
        label_n = len(labels)
        self.context.punctuation_num = label_n
        for i, label in enumerate(labels):
            if i < label_n - 1:
                label_left = label.end
                label_right = labels[i + 1].start
            else:
                label_left = label.end - 0.7
                label_right = label.end + 0.7

            # Earliest punctuation start at or after the window's left edge.
            first_start = bisect.bisect_left(start_times, label_left)
            starts_inside = (
                first_start < len(start_times) and start_times[first_start] <= label_right
            )
            # Latest punctuation end at or before the window's right edge.
            last_end = bisect.bisect_right(end_times, label_right) - 1
            ends_inside = last_end >= 0 and end_times[last_end] >= label_left

            if starts_inside or ends_inside:
                self.context.pred_punctuation_num += 1
        return self.context

    def _gen_init_data(self) -> dict:
        """Build the init message sent right after the websocket connects."""
        return {
            "parameter": {
                "lang": None,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }