import json
import os
import sys
import threading
import time
import traceback
from copy import deepcopy
from typing import Any, List

import websocket
from pydantic_core import ValidationError
from websocket import create_connection

from schemas.context import ASRContext
from schemas.stream import StreamDataModel, StreamResultModel
from utils.logger import logger

# True when SUBMIT_CONFIG_FILEPATH is not set in the environment.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None


class Client:
    """Streaming speech-recognition client for a single audio file.

    ``action()`` connects to ``<sut_url>/recognition`` and performs an init
    handshake. It then streams the audio body of ``context.file_path`` as
    paced binary frames. Meanwhile a background thread collects the text
    frames the server sends back. Finally it returns the send/receive
    timestamps plus the parsed recognition results.
    """

    # Number of connect + handshake attempts before the process exits.
    _MAX_CONNECT_ATTEMPTS = 5
    # Maximum seconds to wait for the final result after all audio is sent.
    _CLOSE_GRACE_S = 20 * 60
    # Maximum seconds to wait for the receiver thread after closing the socket.
    _RECV_JOIN_TIMEOUT_S = 10

    def __init__(self, sut_url: str, context: ASRContext) -> None:
        """Store the connection target and initialise per-run state.

        Args:
            sut_url: Base websocket URL of the service under test;
                ``/recognition`` is appended to it.
            context: Audio / recognition settings. It is deep-copied so later
                mutation by the caller does not affect this client.
        """
        self.base_url = sut_url + "/recognition"
        logger.info(f"{self.base_url}")
        self.context: ASRContext = deepcopy(context)
        self.connect_num = 0
        self.exception = False
        # perf_counter() deadline for _close; effectively "never" until
        # _send finishes and sets a real deadline.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []
        # Cleared on any error AND when the final merged result arrives.
        # It is the shared stop signal for _send, _recv and _close.
        self.success = True

    def action(self) -> dict:
        """Run one full recognition session and return ``_gen_result()``.

        The connection is attempted up to ``_MAX_CONNECT_ATTEMPTS`` times.
        Attempts are separated by ``$connect_sleep`` seconds (default 10).
        The process exits with status -1 if every attempt fails.
        """
        for attempt in range(1, self._MAX_CONNECT_ATTEMPTS + 1):
            try:
                self._connect_init()
                break
            except Exception as e:
                logger.error(f"第 {attempt} 次连接失败,原因:{e}")
                # Do not leak the half-initialised socket into the next try.
                self._discard_ws()
                if attempt < self._MAX_CONNECT_ATTEMPTS:
                    time.sleep(int(os.getenv("connect_sleep", 10)))
        else:
            # Every attempt failed.
            sys.exit(-1)

        self.trecv = threading.Thread(target=self._recv)
        self.trecv.start()
        self._send()
        self._close()
        # Let the receiver finish before results are assembled, so
        # _gen_result does not race with late appends. The wait is bounded
        # so a stuck thread cannot hang the caller.
        self.trecv.join(timeout=self._RECV_JOIN_TIMEOUT_S)
        return self._gen_result()

    def _connect_init(self) -> None:
        """Open the websocket and perform the init handshake.

        Sends ``_gen_init_data()`` as JSON. It then waits up to
        ``$end_time`` seconds (default 2) for a JSON-object reply whose
        ``success`` field is truthy.

        Raises:
            Exception: on disconnect, refused connection, a non-JSON text
                reply, or no successful reply within the time budget.
                ``action`` treats these as retryable.

        A reply that is not a JSON object, or a JSON object with a falsy
        ``success``, terminates the process with status -1.
        """
        timeout = float(os.getenv("end_time", 2))
        deadline = time.time() + timeout
        success = False
        try:
            self.ws = create_connection(self.base_url)
            self.ws.send(json.dumps(self._gen_init_data()))
            while True:
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                # Bound the blocking recv by the remaining budget. Otherwise
                # a silent server blocks forever and the deadline is never
                # checked again.
                self.ws.settimeout(remaining)
                try:
                    data = self.ws.recv()
                except websocket.WebSocketTimeoutException:
                    break
                logger.info(f"data {data}")
                if len(data) == 0:
                    time.sleep(1)
                    continue
                if isinstance(data, str):
                    try:
                        data = json.loads(data)
                    except Exception:
                        raise Exception("初始化阶段,数据不是 json 字符串格式,终止流程") from None
                if not isinstance(data, dict):
                    logger.error("初始化阶段,数据不是 json 字符串格式,终止流程")
                    sys.exit(-1)
                success = data.get("success", False)
                if not success:
                    logger.error(f"初始化失败,返回的结果为 {data},终止流程")
                    sys.exit(-1)
                break
        except (websocket.WebSocketConnectionClosedException, TimeoutError) as e:
            raise Exception("初始化阶段连接中断,终止流程") from e
        except ConnectionRefusedError as e:
            raise Exception("初始化阶段,连接失败,等待 10s 后重试,最多重试 5 次") from e

        if not success:
            raise Exception(f"初始化阶段 {timeout:g}s 没有返回数据,时间太长,终止流程")
        # Streaming phase: back to fully blocking receives.
        self.ws.settimeout(None)
        logger.info("建立连接成功")
        self.connect_num = 0

    def _discard_ws(self) -> None:
        """Best-effort close of a socket left behind by a failed init attempt."""
        ws = getattr(self, "ws", None)
        if ws is None:
            return
        try:
            ws.close()
        except Exception as e:
            logger.error(f"关闭连接失败:{e}")

    def _send(self) -> None:
        """Stream the audio body of ``context.file_path`` to the server.

        The WAV header is skipped: everything up to the first ``data`` tag
        plus its 4-byte size field. A frame of ``context.chunk_size`` bytes
        is sent every ``context.wait_time`` seconds. When the loop falls
        behind, the next frame carries one chunk per elapsed interval.
        Sending stops early if a single send takes longer than
        ``$send_interval`` seconds (default 60). A ``{"end": true}`` text
        frame marks the end of audio. Afterwards ``close_time`` is set so
        ``_close`` waits at most ``_CLOSE_GRACE_S`` for the final result.
        """
        send_ts = float(os.getenv("send_interval", 60))
        if not self.success:
            return
        try:
            with open(self.context.file_path, "rb") as fp:
                wav_data = fp.read()
            # Raises ValueError (handled below) if there is no "data" tag.
            pcm = wav_data[wav_data.index(b"data") + 8 :]
            offset = 0
            last_send_time = None
            while True:
                now_time = time.perf_counter()
                if last_send_time is None:
                    size = int(self.context.chunk_size)
                else:
                    # Catch up: one chunk per wait_time interval elapsed
                    # since the previous frame (at least one).
                    interval_cnt = max(
                        int((now_time - last_send_time) / self.context.wait_time),
                        1,
                    )
                    size = int(self.context.chunk_size * interval_cnt)
                chunk = pcm[offset : offset + size]
                if not chunk:
                    break
                offset += len(chunk)

                send_time_start = time.perf_counter()
                self.ws.send(chunk, websocket.ABNF.OPCODE_BINARY)
                self.send_time.append(send_time_start)
                last_send_time = send_time_start
                send_time_end = time.perf_counter()
                if send_time_end - send_time_start > send_ts:
                    logger.error(f"发送延迟已经超过 {send_ts}s, 终止当前音频发送")
                    break
                sleep_time = self.context.wait_time + now_time - send_time_end
                if sleep_time > 0:
                    time.sleep(sleep_time)
            logger.info("当条语音数据发送完成")
            self.ws.send(json.dumps({"end": True}))
            logger.info(f"等待最终识别结果后关闭双向连接,最长 {self._CLOSE_GRACE_S}s.")
        except BrokenPipeError:
            logger.error("发送数据出错,被测服务出现故障")
            self.success = False
        except Exception as e:
            logger.error(f"Exception: {e}")
            logger.error(traceback.format_exc())
            logger.error("发送数据失败")
            self.success = False
        self.close_time = time.perf_counter() + self._CLOSE_GRACE_S

    def _recv(self) -> None:
        """Receiver loop (run on ``self.trecv``) that collects server results.

        Each text frame is parsed as ``StreamResultModel``, and its
        ``recognition_results`` are appended to ``predict_data``. One frame
        counts as the final merged result: ``final_result`` is set and
        ``start_time``, ``end_time`` and ``para_seq`` are all 0. That frame
        is not stored; it clears ``self.success`` to stop the session. A
        frame that fails schema validation terminates the process.
        """
        try:
            while self.ws.connected and self.success:
                recv_data = self.ws.recv()
                if not isinstance(recv_data, str):
                    self.success = False
                    raise Exception("返回的结果不是字符串形式")
                if not recv_data:
                    continue
                self.recv_time.append(time.perf_counter())
                result = StreamResultModel(**json.loads(recv_data)).recognition_results
                if (
                    result.final_result
                    and result.start_time == 0
                    and result.end_time == 0
                    and result.para_seq == 0
                ):
                    self.success = False
                else:
                    self.predict_data.append(result)
        except websocket.WebSocketConnectionClosedException:
            logger.error("WebSocketConnectionClosedException")
        except ValidationError as e:
            logger.error("返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError:
            # Presumably the socket was torn down by _close() while recv()
            # was blocked; treated as a normal end of the receive loop.
            pass
        except Exception:
            logger.error(traceback.format_exc())
            logger.error("处理被测服务返回数据时出错")
            self.success = False

    def _receiver_alive(self) -> bool:
        """True while the receiver thread runs (or if it was never started)."""
        thread = getattr(self, "trecv", None)
        return thread is None or thread.is_alive()

    def _close(self) -> None:
        """Wait for the session to end, then close the websocket.

        Polls once per second until any of these happens: the final result
        or an error clears ``success``, ``close_time`` passes, or the
        receiver thread exits (nothing more can arrive).
        """
        while (
            self.success
            and time.perf_counter() < self.close_time
            and self._receiver_alive()
        ):
            time.sleep(1)
        try:
            self.ws.close()
        except Exception as e:
            logger.error(f"关闭连接失败:{e}")

    def _gen_result(self) -> dict:
        """Assemble the session result.

        Returns:
            A dict with these keys:
                ``fail``: True when no recognition results were collected.
                ``send_time`` / ``recv_time``: ``perf_counter`` timestamps.
                ``predict_data``: the collected ``recognition_results``
                objects.
        """
        # Snapshot so a late append from a lingering receiver cannot alter
        # the returned list.
        predict_data = list(self.predict_data)
        if not predict_data:
            logger.error("没有任何数据返回")
        return {
            "fail": not predict_data,
            "send_time": self.send_time,
            "recv_time": self.recv_time,
            "predict_data": predict_data,
        }

    def _gen_init_data(self) -> dict:
        """Build the init message sent immediately after connecting."""
        return {
            "parameter": {
                "lang": self.context.lang,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }