Files
zhousha 55a67e817e update
2025-08-06 15:38:55 +08:00

225 lines
9.2 KiB
Python

import json
import os
import sys
import threading
import time
import traceback
from copy import deepcopy
from typing import Any, List

import websocket
from pydantic_core import ValidationError
from websocket import create_connection

from schemas.context import ASRContext
from schemas.stream import StreamDataModel, StreamResultModel
from utils.logger import logger
# True when SUBMIT_CONFIG_FILEPATH is absent from the environment
# (presumably a local/test run rather than a real submission — confirm).
IN_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
class Client:
    """Streaming speech-recognition client for one audio file.

    Opens a websocket to ``<sut_url>/recognition`` and sends the recognition
    parameters taken from the ``ASRContext``. It then streams the file's audio
    payload in paced chunks on the calling thread, while a background thread
    collects the JSON results. Finally it returns timing data plus the parsed
    results.

    ``self.success`` doubles as a "keep running" flag. It is cleared when the
    end-of-stream result arrives or when sending/receiving fails. That stops
    the receive loop and ends the wait in ``_close``.
    """

    def __init__(self, sut_url: str, context: ASRContext) -> None:
        # e.g. sut_url = "ws://127.0.0.1:5003"
        self.base_url = sut_url + "/recognition"
        logger.info(f"{self.base_url}")
        # Deep copy so this client is isolated from later changes to the caller's object.
        self.context: ASRContext = deepcopy(context)
        self.connect_num = 0
        self.exception = False
        # Deadline (time.perf_counter() scale) for waiting on results.
        # Effectively "never" until _send() finishes and arms a real deadline.
        self.close_time = 10**50
        # perf_counter() timestamps of each audio chunk sent / each text frame received.
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        # Raw JSON strings while receiving; replaced by parsed results in _gen_result().
        self.predict_data: List[Any] = []
        self.success = True

    def action(self) -> dict:
        """Run the whole session and return the dict built by ``_gen_result``.

        Connection + initialisation is attempted up to 5 times, sleeping
        ``$connect_sleep`` seconds (default 10) between attempts. If every
        attempt fails, the process exits with status -1.
        """
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                self._connect_init()
                break
            except Exception as e:
                logger.error(f"{attempt} 次连接失败,原因:{e}")
                # No point sleeping after the final attempt: we exit right away.
                if attempt < max_attempts:
                    time.sleep(int(os.getenv("connect_sleep", 10)))
        else:
            sys.exit(-1)

        self.trecv = threading.Thread(target=self._recv)
        self.trecv.start()
        self._send()
        self._close()
        return self._gen_result()

    def _connect_init(self) -> None:
        """Open the websocket, send the init parameters and wait for the ack.

        Waits up to ``$end_time`` seconds (default 2) for a reply. Empty frames
        are polled again after a 1s pause. A JSON object with a truthy
        ``success`` completes initialisation. The process is terminated
        (exit status -1) by:
        - a JSON object without it,
        - JSON that is not an object,
        - a binary frame.

        Raises:
            Exception: if the connection drops, times out or is refused, if a
                text reply is not JSON, or if no acknowledgement arrives before
                the deadline. Other errors from ``create_connection`` propagate
                unchanged.
        """
        init_timeout = float(os.getenv("end_time", 2))
        end_time = time.time() + init_timeout
        success = False
        try:
            self.ws = create_connection(self.base_url)
            self.ws.send(json.dumps(self._gen_init_data()))
            # NOTE: the deadline is only checked between frames. No timeout is
            # passed to create_connection here, so a blocking recv() is not
            # interrupted by it.
            while time.time() < end_time:
                data = self.ws.recv()
                logger.info(f"data {data}")
                if not data:
                    # Empty frame: back off briefly and poll again.
                    time.sleep(1)
                    continue
                if isinstance(data, str):
                    try:
                        data = json.loads(data)
                    except ValueError as e:
                        raise Exception("初始化阶段,数据不是 json 字符串格式,终止流程") from e
                if isinstance(data, dict):
                    success = bool(data.get("success", False))
                    if success:
                        break
                    logger.error(f"初始化失败,返回的结果为 {data},终止流程")
                    sys.exit(-1)
                # Binary frame, or valid JSON that is not an object.
                logger.error("初始化阶段,数据不是 json 字符串格式,终止流程")
                sys.exit(-1)
        # A tuple is required here: ``A or B`` evaluates to just ``A``, which
        # silently left TimeoutError uncaught.
        except (websocket.WebSocketConnectionClosedException, TimeoutError) as e:
            raise Exception("初始化阶段连接中断,终止流程") from e
        except ConnectionRefusedError as e:
            raise Exception("初始化阶段,连接失败,等待 10s 后重试,最多重试 5 次") from e

        if not success:
            # Report the timeout that was actually applied.
            raise Exception(f"初始化阶段 {init_timeout:g}s 没有返回数据,时间太长,终止流程")
        logger.info("建立连接成功")
        self.connect_num = 0

    def _send(self) -> None:
        """Stream the audio payload of ``context.file_path`` in time-paced chunks.

        About ``context.chunk_size`` bytes are sent every ``context.wait_time``
        seconds. When sending falls behind, the next chunk covers every
        interval that has elapsed since the previous send. Afterwards
        ``{"end": True}`` is sent and the result-wait deadline used by
        ``_close`` is armed. Send failures are logged, not raised. Generic
        failures also clear ``self.success``.
        """
        send_ts = float(os.getenv("send_interval", 60))
        if not self.success:
            return
        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
        # Skip the WAV header. Assumes the first b"data" occurrence is the data
        # chunk id, followed by its 4-byte size field (hence + 8).
        offset = wav_data.index(b"data") + 8
        try:
            last_send_time = -1  # -1 means "nothing sent yet"
            while True:
                now_time = time.perf_counter()
                if last_send_time == -1:
                    read_size = int(self.context.chunk_size)
                else:
                    interval_cnt = max(
                        int((now_time - last_send_time) / self.context.wait_time),
                        1,
                    )
                    read_size = int(self.context.chunk_size * interval_cnt)
                chunk = wav_data[offset:offset + read_size]
                if not chunk:
                    break
                offset += len(chunk)
                send_time_start = time.perf_counter()
                self.ws.send(chunk, websocket.ABNF.OPCODE_BINARY)
                self.send_time.append(send_time_start)
                last_send_time = send_time_start
                send_time_end = time.perf_counter()
                if send_time_end - send_time_start > send_ts:
                    logger.error(f"发送延迟已经超过 {send_ts}s, 终止当前音频发送")
                    break
                # Sleep out the rest of this interval, measured from the start of the iteration.
                if (sleep_time := self.context.wait_time + now_time - send_time_end) > 0:
                    time.sleep(sleep_time)
            logger.info("当条语音数据发送完成")
            self.ws.send(json.dumps({"end": True}))
            # NOTE(review): the wait armed below is up to 20 minutes, not the 2s this message says.
            logger.info("2s 后关闭双向连接.")
        except BrokenPipeError:
            logger.error("发送数据出错,被测服务出现故障")
        except Exception as e:
            logger.error(f"Exception: {e}")
            logger.error(traceback.format_exc())
            logger.error("发送数据失败")
            self.success = False
        # Give the server up to 20 minutes to deliver the remaining results.
        self.close_time = time.perf_counter() + 20 * 60

    def _recv(self) -> None:
        """Background loop that collects recognition results until told to stop.

        Each non-empty text frame is timestamped in ``recv_time`` and validated
        against ``StreamResultModel``. A final result whose ``start_time``,
        ``end_time`` and ``para_seq`` are all 0 is treated as the end-of-stream
        marker. It clears ``self.success`` and is not stored. Every other
        frame is kept raw in ``predict_data``. A binary frame, or any
        unexpected error, clears ``self.success``. A schema violation
        terminates the process.
        """
        try:
            while self.ws.connected and self.success:
                recv_data = self.ws.recv()
                if not isinstance(recv_data, str):
                    self.success = False
                    raise Exception("返回的结果不是字符串形式")
                if not recv_data:
                    continue
                self.recv_time.append(time.perf_counter())
                recognition_results = StreamResultModel(**json.loads(recv_data)).recognition_results
                is_end_marker = (
                    recognition_results.final_result
                    and recognition_results.start_time == 0
                    and recognition_results.end_time == 0
                    and recognition_results.para_seq == 0
                )
                if is_end_marker:
                    self.success = False
                else:
                    self.predict_data.append(recv_data)
        except websocket.WebSocketConnectionClosedException:
            logger.error("WebSocketConnectionClosedException")
        except ValidationError as e:
            logger.error("返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            # os._exit ends the whole process even from this worker thread
            # (sys.exit would only end the thread).
            os._exit(1)
        except OSError:
            # Deliberately ignored. Presumably raised when _close() closes the
            # socket under a blocked recv().
            pass
        except Exception:
            logger.error(traceback.format_exc())
            logger.error("处理被测服务返回数据时出错")
            self.success = False

    def _close(self) -> None:
        """Wait for the results to finish arriving, then close the websocket.

        Polls once per second until one of these happens:
        - ``close_time`` passes,
        - ``self.success`` is cleared,
        - the receive thread has exited, so nothing more can arrive.

        Errors raised by ``close()`` are logged and ignored.
        """
        recv_thread = getattr(self, "trecv", None)
        while (
            time.perf_counter() < self.close_time
            and self.success
            and (recv_thread is None or recv_thread.is_alive())
        ):
            time.sleep(1)
        try:
            self.ws.close()
        except Exception as e:
            logger.error(f"Exception while closing websocket: {e}")

    def _gen_result(self) -> dict:
        """Parse the collected frames and package them with the timing data.

        Replaces ``self.predict_data`` with the ``recognition_results`` of each
        stored frame.

        Returns:
            dict with ``fail`` (True when no result frames were collected),
            ``send_time``, ``recv_time`` and ``predict_data``.
        """
        if not self.predict_data:
            logger.error("没有任何数据返回")
        self.predict_data = [
            StreamResultModel(**json.loads(data)).recognition_results
            for data in self.predict_data
        ]
        return {
            "fail": not self.predict_data,
            "send_time": self.send_time,
            "recv_time": self.recv_time,
            "predict_data": self.predict_data,
        }

    def _gen_init_data(self) -> dict:
        """Build the initial JSON message carrying the recognition parameters from the context."""
        return {
            "parameter": {
                "lang": self.context.lang,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }