import json
import os
import threading
import time
import traceback
from typing import Any, List, Optional
from copy import deepcopy

from utils.logger import logger
from websocket import (
    create_connection,
    WebSocketConnectionClosedException,
    ABNF,
)
from utils.model import ASRResponseModel, SegmentModel
import queue
from pydantic_core import ValidationError

_IS_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH") is None


class ASRWebSocketClient:
    """Streaming client for a websocket speech-recognition endpoint.

    Typical usage is ``initialize_connection()`` followed by ``execute()``:
    the first performs the JSON handshake on ``<url>/recognition``, the
    second streams an audio file and collects the recognition results.
    """

    def __init__(self, url: str):
        """Store the target endpoint and reset the bookkeeping fields.

        Args:
            url: Base websocket URL; ``/recognition`` is appended to it.
        """
        self.endpoint = f"{url}/recognition"
        self.ws = None
        # Starts at -1 and is reset to 0 by a successful handshake.
        self.conn_attempts = -1
        self.failed = False
        self.terminate_time = float("inf")
        self.sent_timestamps: List[float] = []
        self.received_timestamps: List[float] = []
        self.results: List[Any] = []
        self.connected = True
        logger.info(f"Target endpoint: {self.endpoint}")

    def execute(self, path: str) -> Optional[List[SegmentModel]]:
        """Stream the audio file at ``path`` and collect the responses.

        One thread sends the audio while a second one receives, so that the
        arrival time of every response can be measured against a common
        start time.

        Args:
            path: Path of the audio file to stream.

        Returns:
            The list produced by ``receive()``, or ``None`` when the receive
            thread raised (the failure is logged).
        """
        send_thread = threading.Thread(target=self.send, args=(path,))
        # Reference point for the receive timestamps.
        start_time = time.time()
        send_thread.start()

        # Carries the return value of receive() back to this thread.
        result_queue = queue.Queue()

        def receive_thread_fn():
            try:
                result_queue.put(self.receive(start_time))
            except Exception:
                # Report the failure instead of dropping it silently; the
                # caller still gets None as before.
                logger.warning(f"接收响应数据失败: {traceback.format_exc()}")
                result_queue.put(None)

        receive_thread = threading.Thread(target=receive_thread_fn)
        receive_thread.start()

        send_thread.join()
        receive_thread.join()

        # The receive thread always puts exactly one item before exiting.
        return result_queue.get()

    def initialize_connection(self, language: str):
        """Open the websocket and perform the initial handshake.

        Connects to ``self.endpoint``, sends the init payload and waits for
        a JSON object whose ``success`` field is truthy. Failed attempts are
        retried until the attempt counter or the deadline is exceeded. The
        deadline is ``end_time`` seconds from now (environment variable,
        default ``2``); it is only evaluated between blocking socket calls.

        Args:
            language: Language code placed in the init payload.

        Raises:
            ConnectionRefusedError: The attempt counter exceeded 5.
            RuntimeError: The deadline passed without a successful handshake.
            Exception: The connection was closed, or a ``TimeoutError`` was
                raised, during the handshake.
        """
        expiration = time.time() + float(os.getenv("end_time", "2"))
        self.connected = False
        init = False
        while True:
            # True = accepted, False = attempt failed, None = inconclusive.
            verdict = None
            msg_empty = False
            try:
                if not init:
                    # Close any socket left over from a previous attempt so
                    # that reconnecting does not leak it.
                    self.shutdown()
                    logger.debug(f"建立ws链接:发送建立ws链接请求。")
                    self.ws = create_connection(self.endpoint)
                    body = json.dumps(self._get_init_payload(language))
                    logger.debug(f"建立ws链接:发送初始化数据。{body}")
                    self.ws.send(body)
                    init = True
                msg = self.ws.recv()
                logger.debug(f"收到响应数据: {msg}")
                msg_empty = len(msg) == 0
                if not msg_empty:
                    if isinstance(msg, str):
                        try:
                            msg = json.loads(msg)
                        except Exception:
                            logger.debug("建立ws链接:响应数据非json格式!")
                            verdict = False
                    if isinstance(msg, dict):
                        verdict = bool(msg.get("success"))
            except (WebSocketConnectionClosedException, TimeoutError) as err:
                # A tuple is required here: ``A or B`` evaluates to ``A``
                # alone, which left TimeoutError unhandled by this clause.
                raise Exception("建立ws链接:初始化阶段连接中断,退出。") from err
            except Exception:
                verdict = False

            # Everything below runs outside the ``try`` so that the errors
            # raised here are not caught and counted a second time.
            if verdict:
                logger.debug("建立ws链接:链接建立成功!")
                self.conn_attempts = 0
                self.connected = True
                return

            if verdict is None:
                # Nothing conclusive received yet; keep waiting, but never
                # past the deadline.
                if time.time() > expiration:
                    raise RuntimeError("建立ws链接:链接建立超时!")
                if msg_empty:
                    time.sleep(0.5)  # give the server time to write back
                continue

            logger.info("建立ws链接:链接建立失败!")
            init = False
            self.conn_attempts = self.conn_attempts + 1
            if self.conn_attempts > 5:
                raise ConnectionRefusedError("重试5次后,仍然无法建立ws链接。")
            if time.time() > expiration:
                raise RuntimeError("建立ws链接:链接建立超时!")

    def shutdown(self):
        """Close the websocket, ignoring close errors, and mark as disconnected."""
        try:
            if self.ws:
                self.ws.close()
        except Exception:
            # Best effort: a failure to close must not mask the caller's flow.
            pass
        self.connected = False

    def _get_init_payload(self, language="zh") -> dict:
        """Return the handshake payload announcing the language."""
        return {"language": language}

    def _get_finish_payload(self) -> dict:
        """Return the payload that marks the end of the audio stream."""
        return {"end": "true"}

    def send(self, path):
        """Stream the file at ``path`` as binary frames, then the finish payload.

        The file is sent in 3200-byte chunks with a 0.1 s pause after each.
        For paths ending in ``wav`` the first 44 bytes (WAV header) are
        skipped.

        Args:
            path: Path of the audio file to stream.
        """
        skip_wav_header = path.endswith("wav")
        try:
            with open(path, "rb") as f:
                if skip_wav_header:
                    # The WAV header is 44 bytes long.
                    f.read(44)
                while chunk := f.read(3200):
                    logger.debug(f"发送 {len(chunk)} 字节数据.")
                    self.ws.send(chunk, opcode=ABNF.OPCODE_BINARY)
                    time.sleep(0.1)
        finally:
            # Always signal end-of-stream; otherwise a failure while reading
            # the file would leave receive() blocked on recv() forever.
            self.ws.send(json.dumps(self._get_finish_payload()))

    def receive(self, start_time) -> List[SegmentModel]:
        """Collect responses until one arrives without ``asr_results``.

        Every response carrying ``asr_results`` is validated with
        ``ASRResponseModel`` and stamped with ``receive_time``, the number of
        milliseconds elapsed since ``start_time``. The first response without
        ``asr_results`` ends the stream.

        Args:
            start_time: ``time.time()`` value taken when sending started.

        Returns:
            The collected items sorted by ``para_seq`` (presumably
            ``SegmentModel`` instances, per the annotation). Gaps in the
            sequence numbers are logged as a warning.
        """
        results = []
        while True:
            msg = self.ws.recv()
            # Arrival time of this response, in milliseconds.
            now = 1000 * (time.time() - start_time)
            logger.debug(f"{now} 收到响应数据: {msg}")
            res = json.loads(msg)
            if res.get("asr_results"):
                item = ASRResponseModel.model_validate_json(msg).asr_results
                item.receive_time = now
                results.append(item)
                continue

            logger.debug("响应结束")
            # Sort by para_seq and check that the numbers are consecutive.
            results.sort(key=lambda seg: seg.para_seq)
            missing_seqs = []
            for current, following in zip(results, results[1:]):
                expected_next = current.para_seq + 1
                if following.para_seq != expected_next:
                    missing_seqs.extend(range(expected_next, following.para_seq))
            if missing_seqs:
                logger.warning(f"检测到丢失的段落序号:{missing_seqs}")
            else:
                logger.debug("响应数据正常")
            return results


if __name__ == '__main__':
    ws_client = ASRWebSocketClient("ws://localhost:18000")
    try:
        ws_client.initialize_connection("zh")
        res = ws_client.execute("/Users/yu/Documents/code-work/asr-live-iluvatar/zh_250312/zh/99.wav")
        # execute() returns None when receiving failed.
        for i in res or []:
            print(i.summary())
            print()
    finally:
        ws_client.shutdown()