Files
enginex-c_series-asr/utils/client.py

196 lines
6.9 KiB
Python
Raw Permalink Normal View History

2025-08-28 18:46:56 +08:00
import json, os, threading, time, traceback
from typing import (
Any, List
)
from copy import deepcopy
from utils.logger import logger
from websocket import (
create_connection,
WebSocketConnectionClosedException,
ABNF
)
from utils.model import (
ASRResponseModel,
SegmentModel
)
import threading
import queue
from pydantic_core import ValidationError
# True when SUBMIT_CONFIG_FILEPATH is absent from the process environment.
_IS_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
class ASRWebSocketClient:
    """WebSocket client that streams an audio file to an ASR service.

    Usage, as shown by this module's script entry point: construct the client
    with the service base URL, call :meth:`initialize_connection` to open the
    socket and perform the language handshake, then call :meth:`execute` with
    the path of the audio file to stream and collect recognised segments.
    """

    def __init__(self, url: str):
        """Store the endpoint and reset the bookkeeping fields.

        Args:
            url: Base WebSocket URL; ``/recognition`` is appended to it.
        """
        self.endpoint = f"{url}/recognition"
        self.ws = None
        # Starts at -1, so the first failed attempt brings the counter to 0.
        self.conn_attempts = -1
        self.failed = False
        self.terminate_time = float("inf")
        self.sent_timestamps: List[float] = []
        self.received_timestamps: List[float] = []
        self.results: List[Any] = []
        self.connected = True
        logger.info(f"Target endpoint: {self.endpoint}")

    def execute(self, path: str) -> "List[SegmentModel]":
        """Stream ``path`` to the server and collect the recognition results.

        Sending and receiving run on two separate threads so that the arrival
        time of every response can be measured relative to the start of the
        upload.

        Args:
            path: Path of the audio file to stream.

        Returns:
            The segments returned by :meth:`receive`, or ``None`` when the
            receiving side failed (the failure is logged).
        """
        send_thread = threading.Thread(target=self.send, args=(path,))
        # Reference point for the per-response latency measurements.
        start_time = time.time()
        send_thread.start()

        # Carries the return value of receive() out of its worker thread.
        result_queue = queue.Queue()

        def receive_thread_fn():
            try:
                result_queue.put(self.receive(start_time))
            except Exception:
                # Keep the historical "None on failure" contract, but do not
                # hide the reason for the failure.
                logger.error(f"接收识别结果失败: {traceback.format_exc()}")
                result_queue.put(None)

        receive_thread = threading.Thread(target=receive_thread_fn)
        receive_thread.start()

        send_thread.join()
        receive_thread.join()
        return result_queue.get()

    def initialize_connection(self, language: str):
        """Open the WebSocket and perform the initial language handshake.

        The handshake is retried until the server answers with a JSON object
        whose ``success`` field is truthy. The time budget, in seconds, is
        read from the ``end_time`` environment variable (default ``2``).

        Args:
            language: Language code sent in the initial payload.

        Raises:
            ConnectionRefusedError: ``conn_attempts`` exceeded 5.
            RuntimeError: A failed attempt finished after the deadline.
            Exception: The connection was closed or timed out during the
                handshake.
        """
        expiration = time.time() + float(os.getenv("end_time", "2"))
        self.connected = False
        init = False
        while True:
            attempt_failed = False
            try:
                if not init:
                    # Release the socket of a previous failed attempt before
                    # opening a new one, otherwise it would leak.
                    self._close_socket()
                    logger.debug(f"建立ws链接发送建立ws链接请求。")
                    self.ws = create_connection(self.endpoint)
                    body = json.dumps(self._get_init_payload(language))
                    logger.debug(f"建立ws链接发送初始化数据。{body}")
                    self.ws.send(body)
                    init = True
                msg = self.ws.recv()
                logger.debug(f"收到响应数据: {msg}")
                if len(msg) == 0:
                    # Nothing written back yet; give the server a moment.
                    time.sleep(0.5)
                    continue
                if isinstance(msg, str):
                    try:
                        msg = json.loads(msg)
                    except Exception as exc:
                        raise Exception("建立ws链接响应数据非json格式") from exc
                if isinstance(msg, dict):
                    if msg.get("success"):
                        logger.debug("建立ws链接链接建立成功")
                        self.conn_attempts = 0
                        self.connected = True
                        return
                    attempt_failed = True
            except (WebSocketConnectionClosedException, TimeoutError) as exc:
                # A tuple is required here: ``A or B`` evaluates to ``A`` only,
                # so TimeoutError would never be matched by this clause.
                raise Exception("建立ws链接初始化阶段连接中断退出。") from exc
            except Exception:
                logger.debug(traceback.format_exc())
                attempt_failed = True

            if attempt_failed:
                # Done outside the ``try`` so the exceptions raised below are
                # not caught and counted a second time by the handler above.
                logger.info("建立ws链接链接建立失败")
                init = False
                self.conn_attempts = self.conn_attempts + 1
                if self.conn_attempts > 5:
                    raise ConnectionRefusedError("重试5次后仍然无法建立ws链接。")
                if time.time() > expiration:
                    raise RuntimeError("建立ws链接链接建立超时")

    def _close_socket(self):
        """Close the current socket, if any, ignoring errors (best effort)."""
        try:
            if self.ws:
                self.ws.close()
        except Exception:
            pass

    def shutdown(self):
        """Close the socket (best effort) and mark the client disconnected."""
        self._close_socket()
        self.connected = False

    def _get_init_payload(self, language="zh") -> dict:
        """Return the handshake payload announcing the recognition language."""
        return {
            "language": language
        }

    def _get_finish_payload(self) -> dict:
        """Return the payload that tells the server the audio stream ended."""
        return {
            "end": "true"
        }

    def send(self, path):
        """Stream the file at ``path`` as binary frames, then send the end marker.

        The file is sent in 3200-byte chunks with a 0.1 s pause after each one.
        NOTE(review): presumably this paces the upload at real-time speed for
        16 kHz 16-bit mono PCM - confirm against the service contract.

        Args:
            path: Path of the audio file. For paths ending in ``wav`` the first
                44 bytes are skipped as the WAV header; headers of any other
                size are not handled.
        """
        skip_wav_header = path.endswith("wav")
        with open(path, "rb") as f:
            if skip_wav_header:
                # Fixed-size skip: the header is taken to be 44 bytes long.
                f.read(44)
            while chunk := f.read(3200):
                logger.debug(f"发送 {len(chunk)} 字节数据.")
                self.ws.send(chunk, opcode=ABNF.OPCODE_BINARY)
                time.sleep(0.1)
        self.ws.send(json.dumps(self._get_finish_payload()))

    @staticmethod
    def _find_missing_seqs(sorted_seqs):
        """Return the integers missing between neighbours of an ascending list."""
        missing = []
        for current, following in zip(sorted_seqs, sorted_seqs[1:]):
            if following != current + 1:
                missing.extend(range(current + 1, following))
        return missing

    def receive(self, start_time) -> "List[SegmentModel]":
        """Read responses until the server sends one without ``asr_results``.

        Every response carrying ``asr_results`` is validated, stamped with its
        arrival time and collected. The first response without that field ends
        the loop.

        Args:
            start_time: ``time.time()`` value taken when the upload started.

        Returns:
            The collected results sorted by ``para_seq``. Gaps in the sequence
            numbers are only logged as a warning.
        """
        results = []
        while True:
            msg = self.ws.recv()
            # Arrival time in milliseconds since the upload started.
            now = 1000 * (time.time() - start_time)
            logger.debug(f"{now} 收到响应数据: {msg}")
            res = json.loads(msg)
            if res.get("asr_results"):
                item = ASRResponseModel.model_validate_json(msg).asr_results
                item.receive_time = now
                results.append(item)
                continue

            logger.debug(f"响应结束")
            # Order by para_seq and check that the numbering is contiguous.
            results.sort(key=lambda x: x.para_seq)
            missing_seqs = self._find_missing_seqs(
                [segment.para_seq for segment in results]
            )
            if missing_seqs:
                logger.warning(f"检测到丢失的段落序号:{missing_seqs}")
            else:
                logger.debug("响应数据正常")
            return results
if __name__ == '__main__':
    # Manual smoke test against a locally running recognition service.
    import sys

    # An optional first command-line argument overrides the sample recording.
    audio_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Users/yu/Documents/code-work/asr-live-iluvatar/zh_250312/zh/99.wav"
    )
    ws_client = ASRWebSocketClient("ws://localhost:18000")
    try:
        ws_client.initialize_connection("zh")
        res = ws_client.execute(audio_path)
        # execute() returns None when the receiving side failed.
        for i in res or []:
            print(i.summary())
            print()
    finally:
        # Always release the socket, even when the run failed.
        ws_client.shutdown()