225 lines
9.2 KiB
Python
225 lines
9.2 KiB
Python
import json
|
|
import os
|
|
import threading
|
|
import time
|
|
import traceback
|
|
from copy import deepcopy
|
|
from typing import Any, List
|
|
|
|
import websocket
|
|
from pydantic_core import ValidationError
|
|
from websocket import create_connection
|
|
|
|
from schemas.context import ASRContext
|
|
from schemas.stream import StreamDataModel, StreamResultModel
|
|
from utils.logger import logger
|
|
|
|
# True when the SUBMIT_CONFIG_FILEPATH environment variable is not set.
IN_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
|
|
|
|
|
|
class Client:
    """Streaming ASR test client that talks to the service under test over a websocket.

    One instance handles one audio file: it opens the connection, sends the
    init parameters, streams the audio payload in timed chunks from the
    calling thread while a background thread collects the recognition
    results, then closes the connection and returns the collected data.
    """

    # Number of times `action` tries to establish and initialise the connection.
    _MAX_CONNECT_ATTEMPTS = 5
    # Upper bound (seconds) to wait for the final result after sending ends.
    _CLOSE_GRACE_SECONDS = 20 * 60

    def __init__(self, sut_url: str, context: "ASRContext") -> None:
        """Store a private copy of *context* and reset all per-run state.

        Args:
            sut_url: Base websocket URL of the service (for example
                ``ws://127.0.0.1:5003``); ``/recognition`` is appended.
            context: Audio/session description; deep-copied so later changes
                by the caller do not affect this client.
        """
        self.base_url = sut_url + "/recognition"
        logger.info(f"{self.base_url}")
        self.context: "ASRContext" = deepcopy(context)
        self.ws = None
        self.trecv = None
        self.connect_num = 0
        self.exception = False
        # Effectively "never" until `_send` sets the real deadline.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []
        # Shared run flag: cleared by the receiver on the final result or on
        # an error, and by the sender when sending fails.
        self.success = True

    def action(self):
        """Run the full connect / send / receive / close cycle.

        Returns:
            The dict built by `_gen_result`.

        Raises:
            SystemExit: With code -1 when every connection attempt failed.
        """
        # Give up once all initialisation attempts have failed.
        for attempt in range(1, self._MAX_CONNECT_ATTEMPTS + 1):
            try:
                self._connect_init()
                break
            except Exception as e:
                logger.error(f"第 {attempt} 次连接失败,原因:{e}")
                # Do not leak the half-initialised socket into the next attempt.
                self._discard_connection()
                if attempt < self._MAX_CONNECT_ATTEMPTS:
                    time.sleep(int(os.getenv("connect_sleep", 10)))
        else:
            # `exit()` is the interactive `site` helper; raise SystemExit directly.
            raise SystemExit(-1)

        self.trecv = threading.Thread(target=self._recv)
        self.trecv.start()
        self._send()
        self._close()
        return self._gen_result()

    def _discard_connection(self) -> None:
        """Close and forget the current websocket, ignoring close errors."""
        ws, self.ws = self.ws, None
        if ws is None:
            return
        try:
            ws.close()
        except Exception as e:
            logger.error(f"Exception: {e}")

    def _receiver_alive(self) -> bool:
        """Return True while the background receiver thread is still running."""
        receiver = getattr(self, "trecv", None)
        return receiver is not None and receiver.is_alive()

    def _connect_init(self):
        """Open the websocket, send the init parameters and wait for the ack.

        The wait is bounded by the ``end_time`` environment variable
        (seconds, default 2), checked between received messages.

        Raises:
            Exception: When the connection drops or is refused, a text reply
                is not valid JSON, or no successful reply arrives in time.
            SystemExit: With code -1 when the reply is not a successful JSON
                object.
        """
        end_time = time.time() + float(os.getenv("end_time", 2))
        success = False
        try:
            self.ws = create_connection(self.base_url)
            self.ws.send(json.dumps(self._gen_init_data()))
            while time.time() < end_time and not success:
                data = self.ws.recv()
                logger.info(f"data {data}")
                if len(data) == 0:
                    time.sleep(1)
                    continue
                if isinstance(data, str):
                    try:
                        data = json.loads(data)
                    except Exception as err:
                        raise Exception("初始化阶段,数据不是 json 字符串格式,终止流程") from err
                if isinstance(data, dict):
                    success = data.get("success", False)
                    if success:
                        break
                    logger.error(f"初始化失败,返回的结果为 {data},终止流程")
                # Reached for a failed reply or a non-object payload: abort the run.
                logger.error("初始化阶段,数据不是 json 字符串格式,终止流程")
                raise SystemExit(-1)
        except (websocket.WebSocketConnectionClosedException, TimeoutError) as err:
            # A tuple is required here: `A or B` evaluates to `A` alone.
            raise Exception("初始化阶段连接中断,终止流程") from err
        except ConnectionRefusedError as err:
            raise Exception("初始化阶段,连接失败,等待 10s 后重试,最多重试 5 次") from err

        if not success:
            raise Exception("初始化阶段 60s 没有返回数据,时间太长,终止流程")
        logger.info("建立连接成功")
        self.connect_num = 0

    def _send(self):
        """Stream the audio payload in timed chunks, then send the end marker.

        The WAV header (everything up to 8 bytes past the first ``b"data"``
        marker) is skipped. One chunk of ``context.chunk_size`` bytes is sent
        per ``context.wait_time`` seconds; when the loop falls behind, the
        next chunk is scaled by the number of elapsed intervals. On exit the
        close deadline used by `_close` is set.

        Raises:
            OSError: If the audio file cannot be read.
            ValueError: If the file contains no ``b"data"`` marker.
        """
        send_ts = float(os.getenv("send_interval", 60))
        if not self.success:
            return

        # Read the file once; the payload is sliced from this buffer below.
        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
        # Skip the WAV header.
        offset = wav_data.index(b"data") + 8

        try:
            # Start time of the previous chunk; None before the first send.
            last_send_time = None
            while True:
                now_time = time.perf_counter()
                if last_send_time is None:
                    size = int(self.context.chunk_size)
                else:
                    interval_cnt = max(
                        int((now_time - last_send_time) / self.context.wait_time),
                        1,
                    )
                    size = int(self.context.chunk_size * interval_cnt)
                chunk = wav_data[offset:offset + size]
                if not chunk:
                    break
                offset += len(chunk)

                send_time_start = time.perf_counter()
                self.ws.send(chunk, websocket.ABNF.OPCODE_BINARY)
                self.send_time.append(send_time_start)
                last_send_time = send_time_start
                send_time_end = time.perf_counter()
                if send_time_end - send_time_start > send_ts:
                    logger.error(f"发送延迟已经超过 {send_ts}s, 终止当前音频发送")
                    break
                if (sleep_time := self.context.wait_time + now_time - send_time_end) > 0:
                    time.sleep(sleep_time)
            logger.info("当条语音数据发送完成")
            self.ws.send(json.dumps({"end": True}))
            logger.info("2s 后关闭双向连接.")
        except BrokenPipeError:
            logger.error("发送数据出错,被测服务出现故障")
        except Exception as e:
            logger.error(f"Exception: {e}")
            # format_exc() returns the traceback text; print_exc() returns None.
            logger.error(traceback.format_exc())
            logger.error("发送数据失败")
            self.success = False
        self.close_time = time.perf_counter() + self._CLOSE_GRACE_SECONDS

    def _recv(self):
        """Receiver thread body: collect result messages until the run ends.

        Each non-empty text message is timestamped and validated. The message
        whose result is final with start time, end time and paragraph
        sequence all zero marks the end of the stream and clears
        ``self.success``; every other message is stored in
        ``self.predict_data``. A message that fails schema validation
        terminates the whole process via ``os._exit(1)``.
        """
        try:
            while self.ws.connected and self.success:
                recv_data = self.ws.recv()
                if not isinstance(recv_data, str):
                    self.success = False
                    raise Exception("返回的结果不是字符串形式")
                if not recv_data:
                    continue
                self.recv_time.append(time.perf_counter())
                # Stop only once the final merged result has been seen.
                results = StreamResultModel(**json.loads(recv_data)).recognition_results
                is_final_marker = (
                    results.final_result
                    and results.start_time == 0
                    and results.end_time == 0
                    and results.para_seq == 0
                )
                if is_final_marker:
                    self.success = False
                else:
                    self.predict_data.append(recv_data)
        except websocket.WebSocketConnectionClosedException:
            logger.error("WebSocketConnectionClosedException")
        except ValidationError as e:
            logger.error("返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError as e:
            # Presumably raised when `_close` shuts the socket under a blocked
            # recv(); treated as a normal end of the receive loop.
            logger.info(f"recv loop stopped by socket error: {e}")
        except Exception:
            logger.error(traceback.format_exc())
            logger.error("处理被测服务返回数据时出错")
            self.success = False

    def _close(self):
        """Wait for the run to finish, then close the websocket.

        Polls once per second until the close deadline passes, the run flag
        is cleared, or the receiver thread has stopped. Once the receiver is
        gone nothing can clear the flag any more, so waiting on would only
        burn the whole grace period.
        """
        while (
            time.perf_counter() < self.close_time
            and self.success
            and self._receiver_alive()
        ):
            time.sleep(1)
        try:
            self.ws.close()
        except Exception as e:
            logger.error(f"Exception: {e}")

    def _gen_result(self) -> dict:
        """Parse the collected messages and build the result summary.

        Returns:
            A dict with ``fail`` (True when nothing was received),
            ``send_time`` and ``recv_time`` (``time.perf_counter`` stamps) and
            ``predict_data`` (the parsed ``recognition_results`` objects).
        """
        if not self.predict_data:
            logger.error("没有任何数据返回")
        # Entries are raw JSON strings on the first call; leave entries that
        # were already parsed untouched so a repeated call does not fail.
        self.predict_data = [
            StreamResultModel(**json.loads(item)).recognition_results
            if isinstance(item, str)
            else item
            for item in self.predict_data
        ]
        return {
            "fail": not self.predict_data,
            "send_time": self.send_time,
            "recv_time": self.recv_time,
            "predict_data": self.predict_data,
        }

    def _gen_init_data(self) -> dict:
        """Return the init message payload built from the stored context."""
        ctx = self.context
        return {
            "parameter": {
                "lang": ctx.lang,
                "sample_rate": ctx.sample_rate,
                "channel": ctx.channel,
                "format": ctx.audio_format,
                "bits": ctx.bits,
                "enable_words": ctx.enable_words,
            }
        }
|