update
This commit is contained in:
224
utils/client.py
Normal file
224
utils/client.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Any, List
|
||||
|
||||
import websocket
|
||||
from pydantic_core import ValidationError
|
||||
from websocket import create_connection
|
||||
|
||||
from schemas.context import ASRContext
|
||||
from schemas.stream import StreamDataModel, StreamResultModel
|
||||
from utils.logger import logger
|
||||
|
||||
# True when the SUBMIT_CONFIG_FILEPATH environment variable is not set.
IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None
|
||||
|
||||
|
||||
class Client:
    """Websocket client that streams one audio file to a recognition service.

    The client connects to ``<sut_url>/recognition``, sends a JSON parameter
    message built from the context, streams the audio file as binary frames
    from the calling thread, and collects the text responses on a background
    thread. ``action()`` drives the whole exchange and returns the collected
    timings and parsed results.
    """

    def __init__(self, sut_url: str, context: "ASRContext") -> None:
        """Store the endpoint and a private copy of ``context``.

        Args:
            sut_url: Base websocket URL of the service; ``/recognition`` is
                appended to it.
            context: Per-audio settings. It is deep-copied, so later changes
                by the caller do not affect this client.
        """
        self.base_url = sut_url + "/recognition"
        logger.info(f"{self.base_url}")
        self.context: ASRContext = deepcopy(context)
        self.connect_num = 0
        self.exception = False
        # Deadline (perf_counter clock) after which _close() stops waiting.
        # Effectively "never" until _send() replaces it.
        self.close_time = 10**50
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        self.predict_data: List[Any] = []
        # Doubles as the "keep running" flag shared by _send/_recv/_close.
        self.success = True

    def action(self):
        """Run one full session and return the result dict of ``_gen_result``.

        The initial handshake is attempted up to 5 times, sleeping
        ``connect_sleep`` seconds (environment variable, default 10) between
        attempts.

        Raises:
            SystemExit: With code -1 when every connection attempt fails.
        """
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                self._connect_init()
                break
            except Exception as e:
                logger.error(f"第 {attempt} 次连接失败,原因:{e}")
                # No point in sleeping after the final attempt.
                if attempt < max_attempts:
                    time.sleep(int(os.getenv("connect_sleep", 10)))
        else:
            # Equivalent to exit(-1) without relying on the site-injected builtin.
            raise SystemExit(-1)

        self.trecv = threading.Thread(target=self._recv)
        self.trecv.start()
        self._send()
        self._close()
        # Give the receiver a moment to finish so that it is not appending to
        # predict_data while _gen_result() rebinds it.
        self.trecv.join(timeout=5)
        return self._gen_result()

    def _connect_init(self):
        """Open the websocket and wait for the service to accept the init message.

        Raises:
            Exception: When the connection is refused or interrupted, the
                reply is not valid JSON, or no successful reply arrives
                before the deadline.
            SystemExit: With code -1 when the service replies with something
                other than a successful JSON object.
        """
        try:
            self._handshake()
        except BaseException:
            # A failed attempt must not leave its socket open across retries.
            self._discard_connection()
            raise
        logger.info("建立连接成功")
        self.connect_num = 0

    def _handshake(self):
        """Connect, send the init parameters and read replies until accepted."""
        # The deadline is only checked between receives; recv() itself blocks.
        deadline = time.time() + float(os.getenv("end_time", 2))
        success = False
        try:
            self.ws = create_connection(self.base_url)
            self.ws.send(json.dumps(self._gen_init_data()))
            while time.time() < deadline and not success:
                data = self.ws.recv()
                logger.info(f"data {data}")
                if not data:
                    time.sleep(1)
                    continue
                if isinstance(data, str):
                    try:
                        data = json.loads(data)
                    except Exception as e:
                        raise Exception("初始化阶段,数据不是 json 字符串格式,终止流程") from e
                if isinstance(data, dict):
                    success = data.get("success", False)
                    if success:
                        break
                    logger.error(f"初始化失败,返回的结果为 {data},终止流程")
                # Reached for a non-object reply or an unsuccessful one.
                logger.error("初始化阶段,数据不是 json 字符串格式,终止流程")
                raise SystemExit(-1)
        except (websocket.WebSocketConnectionClosedException, TimeoutError) as e:
            # A tuple is required here: "A or B" would only ever catch A.
            raise Exception("初始化阶段连接中断,终止流程") from e
        except ConnectionRefusedError as e:
            raise Exception("初始化阶段,连接失败,等待 10s 后重试,最多重试 5 次") from e

        if not success:
            # The message text says 60s; the real limit is the "end_time"
            # environment variable (default 2 seconds).
            raise Exception("初始化阶段 60s 没有返回数据,时间太长,终止流程")

    def _discard_connection(self):
        """Best-effort close of whatever socket a failed handshake left behind."""
        ws = getattr(self, "ws", None)
        if ws is None:
            return
        try:
            ws.close()
        except Exception as e:
            logger.error(f"Exception: {e}")

    def _send(self):
        """Stream the audio payload in paced chunks, then send the end marker.

        One chunk of ``context.chunk_size`` bytes is sent per
        ``context.wait_time`` seconds. When the loop falls behind, the next
        read is scaled up by the number of whole intervals that elapsed.
        Sending stops early if a single send takes longer than
        ``send_interval`` seconds (environment variable, default 60).
        Afterwards ``close_time`` is set to 20 minutes from now.
        """
        send_ts = float(os.getenv("send_interval", 60))
        if not self.success:
            return

        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
        # Skip everything up to and including the 8 bytes that start at the
        # first b"data" marker (presumably the wav data-chunk id and size).
        meta_length = wav_data.index(b"data") + 8

        try:
            with open(self.context.file_path, "rb") as fp:
                # Drop the wav header.
                fp.seek(meta_length)
                # Time the previous chunk was sent; None before the first one.
                last_send_time = None
                while True:
                    now_time = time.perf_counter()
                    if last_send_time is None:
                        interval_cnt = 1
                    else:
                        interval_cnt = max(
                            int((now_time - last_send_time) / self.context.wait_time),
                            1,
                        )
                    chunk = fp.read(int(self.context.chunk_size * interval_cnt))
                    if not chunk:
                        break
                    send_time_start = time.perf_counter()
                    self.ws.send(chunk, websocket.ABNF.OPCODE_BINARY)
                    self.send_time.append(send_time_start)
                    last_send_time = send_time_start
                    send_time_end = time.perf_counter()
                    if send_time_end - send_time_start > send_ts:
                        logger.error(f"发送延迟已经超过 {send_ts}s, 终止当前音频发送")
                        break
                    # Sleep out the remainder of this interval, if any.
                    sleep_time = self.context.wait_time + now_time - send_time_end
                    if sleep_time > 0:
                        time.sleep(sleep_time)
                logger.info("当条语音数据发送完成")
                self.ws.send(json.dumps({"end": True}))
                logger.info("2s 后关闭双向连接.")
        except BrokenPipeError:
            logger.error("发送数据出错,被测服务出现故障")
        except Exception as e:
            logger.error(f"Exception: {e}")
            # format_exc() returns the traceback text; print_exc() returns None.
            logger.error(traceback.format_exc())
            logger.error("发送数据失败")
            self.success = False
        self.close_time = time.perf_counter() + 20 * 60

    def _recv(self):
        """Receive loop run on the background thread.

        Every non-empty text message is timestamped and validated. The
        message whose result has ``final_result`` set and ``start_time``,
        ``end_time`` and ``para_seq`` all equal to 0 ends the session; every
        other message is kept (as raw text) in ``predict_data``. A response
        that fails schema validation terminates the process via
        ``os._exit(1)``.
        """
        try:
            while self.ws.connected and self.success:
                recv_data = self.ws.recv()
                if not isinstance(recv_data, str):
                    self.success = False
                    raise Exception("返回的结果不是字符串形式")
                if not recv_data:
                    continue
                self.recv_time.append(time.perf_counter())
                # Only stop once the final merged result has been seen.
                recognition_results = StreamResultModel(**json.loads(recv_data)).recognition_results
                is_end_marker = (
                    recognition_results.final_result
                    and recognition_results.start_time == 0
                    and recognition_results.end_time == 0
                    and recognition_results.para_seq == 0
                )
                if is_end_marker:
                    self.success = False
                else:
                    self.predict_data.append(recv_data)
        except websocket.WebSocketConnectionClosedException:
            logger.error("WebSocketConnectionClosedException")
        except ValidationError as e:
            logger.error("返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            os._exit(1)
        except OSError:
            # Deliberately ignored, as in the original implementation.
            pass
        except Exception:
            logger.error(traceback.format_exc())
            logger.error("处理被测服务返回数据时出错")
            self.success = False

    def _close(self):
        """Wait for the session to end (or ``close_time``), then close the socket.

        Waiting also stops once the receiver thread has exited: after that
        nothing can change the session state, so sleeping until the deadline
        would only waste time.
        """
        receiver = getattr(self, "trecv", None)
        while (
            time.perf_counter() < self.close_time
            and self.success
            and (receiver is None or receiver.is_alive())
        ):
            time.sleep(1)
        try:
            self.ws.close()
        except Exception as e:
            logger.error(f"Exception: {e}")

    def _gen_result(self) -> dict:
        """Parse the collected responses and package them with the timings.

        ``predict_data`` is replaced by the parsed ``recognition_results`` of
        each stored message.

        Returns:
            A dict with ``fail`` (True when nothing was received),
            ``send_time``, ``recv_time`` and ``predict_data``.
        """
        if not self.predict_data:
            logger.error("没有任何数据返回")
        self.predict_data = [
            StreamResultModel(**json.loads(data)).recognition_results
            for data in self.predict_data
        ]
        return {
            "fail": not self.predict_data,
            "send_time": self.send_time,
            "recv_time": self.recv_time,
            "predict_data": self.predict_data,
        }

    def _gen_init_data(self) -> dict:
        """Build the init message sent right after connecting.

        Returns:
            ``{"parameter": {...}}`` filled from the context's ``lang``,
            ``sample_rate``, ``channel``, ``audio_format`` (sent as
            ``format``), ``bits`` and ``enable_words``.
        """
        ctx = self.context
        return {
            "parameter": {
                "lang": ctx.lang,
                "sample_rate": ctx.sample_rate,
                "channel": ctx.channel,
                "format": ctx.audio_format,
                "bits": ctx.bits,
                "enable_words": ctx.enable_words,
            }
        }
|
||||
Reference in New Issue
Block a user