278 lines
11 KiB
Python
278 lines
11 KiB
Python
import asyncio
|
||
import json
|
||
import os
|
||
import time
|
||
import traceback
|
||
from copy import deepcopy
|
||
from enum import Enum
|
||
from typing import Any, List
|
||
|
||
import websockets
|
||
from pydantic_core import ValidationError
|
||
|
||
from schemas.context import ASRContext
|
||
from schemas.stream import StreamResultModel, StreamWordsModel
|
||
from utils.logger import logger
|
||
|
||
# True when the SUBMIT_CONFIG_FILEPATH environment variable is absent
# (presumably a local / test run rather than a real submission — confirm).
IN_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
|
||
|
||
|
||
class STATUS_DATA(str, Enum):
    """Status values for the init handshake between the sender and receiver tasks.

    ``ClientAsync`` puts ``WAITING_FIRST_INIT`` on a queue after sending the
    init message and expects ``FIRST_FAIL`` or ``SUCCESS`` back. Because the
    enum mixes in ``str``, members compare equal to their string values.
    """

    WAITING_FIRST_INIT = "waiting_first_init"
    FIRST_FAIL = "fail"
    WAITING_SECOND_INIT = "waiting_second_init"
    # Historical name: despite "_INIT", this member's value is "second_fail".
    SECOND_INIT = "second_fail"
    WAITING_THIRD_INIT = "waiting_third_init"
    # Historical name: despite "_INIT", this member's value is "third_fail".
    THIRD_INIT = "third_fail"
    SUCCESS = "success"
    CLOSED = "closed"

    # Correctly named aliases. Because the values duplicate the members above,
    # Enum makes these aliases (SECOND_FAIL is SECOND_INIT, and so on) rather
    # than new members. Iteration order and value lookup are therefore unchanged.
    SECOND_FAIL = "second_fail"
    THIRD_FAIL = "third_fail"
|
||
|
||
|
||
class ClientAsync:
    """Streaming-ASR test client for a single audio sample.

    The client connects to ``<sut_url>/recognition`` and sends an init message.
    It then streams the audio payload of ``context.file_path`` in
    ``chunk_size`` pieces, paced by ``wait_time``, and collects the recognition
    results the service sends back. ``action()`` runs the session and returns
    the (deep-copied) context with predictions and scores filled in.
    """

    # Punctuation tokens counted per ``pred.language``. Any language not listed
    # falls back to ``_DEFAULT_PUNCTUATIONS``.
    # NOTE(review): "zh" mixes ASCII ",", "!", "?" with "。". If the service
    # emits full-width "，！？", those tokens are not counted — confirm.
    # NOTE(review): "།" is the Tibetan shad — confirm that "ti" is meant here.
    _DEFAULT_PUNCTUATIONS = (",", ".", "!", "?")
    _PUNCTUATIONS_BY_LANGUAGE = {
        "zh": (",", "。", "!", "?"),
        "ja": ("、", "。", "!", "?"),
        "ar": ("،", ".", "!", "؟"),
        "fa": ("،", ".", "!", "؟"),
        "el": (",", ".", "!", ";"),
        "ti": ("།",),
    }

    def __init__(self, sut_url: str, context: ASRContext, idx: int) -> None:
        """
        Args:
            sut_url: websocket base URL of the service under test, e.g.
                ``"ws://127.0.0.1:5003"``. ``/recognition`` is appended to it.
            context: per-sample test context. It is deep-copied, so results
                written into it do not touch the caller's object.
            idx: sample index, used in log messages.
        """
        self.base_url = sut_url + "/recognition"
        self.context: ASRContext = deepcopy(context)
        self.idx = idx
        # Number of failed session attempts; _action stops retrying at 3.
        self.fail_count = 0
        # Set by _sender to "3 s after the end marker was sent".
        # NOTE(review): not read anywhere in this class — confirm whether an
        # external caller uses it.
        self.close_time = 10**50
        # perf_counter() timestamps of each audio chunk send / each text message received.
        self.send_time: List[float] = []
        self.recv_time: List[float] = []
        # Recognition results (``StreamResultModel.asr_results``) kept for scoring.
        self.predict_data: List[Any] = []

    async def _sender(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Send the init message, then stream the audio followed by an end marker.

        Handshake with ``_recv``:
          1. After the init message is sent, this task puts
             ``WAITING_FIRST_INIT`` on ``send_queue``.
          2. It then waits for ``_recv``'s verdict on ``recv_queue``.
          3. On ``FIRST_FAIL`` it returns without sending any audio.

        While streaming, the "char contains rate" is also checked:
          - When the sent-audio position (``wav_time``) first reaches a label's
            ``start``, a check is scheduled 3 s of wall-clock time later.
          - At each due check, if the total predicted text length divided by
            ``context.char_contains_rate`` is below the total answer length of
            the labels checked so far, ``context.fail_char_contains_rate_num``
            is incremented.
          - Checks are only evaluated while chunks are still being sent. Checks
            that fall due after the last chunk are not evaluated.

        Errors during streaming are logged, not raised. Errors while sending
        the init message or reading the WAV header propagate to the caller.
        """
        # Raise the transport's write-buffer high-water mark to 1 GiB,
        # presumably so that chunk sends rarely wait on flow control.
        websocket.transport.set_write_buffer_limits(1024 * 1024 * 1024)

        # Send the init message and wait for _recv's verdict on it.
        await websocket.send(json.dumps(self._gen_init_data()))
        await send_queue.put(STATUS_DATA.WAITING_FIRST_INIT)
        connect_status = await recv_queue.get()
        if connect_status == STATUS_DATA.FIRST_FAIL:
            return

        # Header length = up to and including the b"data" chunk id and its
        # 4-byte size field.
        # NOTE(review): this assumes the first b"data" occurrence is the data
        # chunk header.
        with open(self.context.file_path, "rb") as fp:
            wav_data = fp.read()
        meta_length = wav_data.index(b"data") + 8

        try:
            with open(self.context.file_path, "rb") as fp:
                fp.read(meta_length)  # skip the WAV header

                chunk_size = int(self.context.chunk_size)
                labels = self.context.labels
                # Seconds of audio sent so far. Assumes each chunk lasts wait_time.
                wav_time = 0.0
                label_id = 0  # first label whose start has not been reached yet
                check_times: List[float] = []  # perf_counter deadline per reached label
                check_id = 0  # next check not yet evaluated
                label_text_len = 0  # total answer length of labels[:check_id]

                while True:
                    now_time = time.perf_counter()
                    chunk = fp.read(chunk_size)
                    if not chunk:
                        break
                    wav_time += self.context.wait_time

                    try:
                        self.send_time.append(time.perf_counter())
                        # Best effort: a send that does not finish within 80 ms
                        # is cancelled by wait_for, and streaming continues.
                        await asyncio.wait_for(websocket.send(chunk), timeout=0.08)
                    except asyncio.TimeoutError:
                        pass

                    # Schedule a check 3 s from now for every label whose start
                    # the audio has just reached.
                    while label_id < len(labels) and wav_time >= labels[label_id].start:
                        check_times.append(now_time + 3.0)
                        label_id += 1

                    # Evaluate all due checks. The predicted length is only
                    # summed when a check is actually due.
                    if check_id < len(check_times) and check_times[check_id] <= now_time:
                        predict_text_len = sum(len(x.text) for x in self.predict_data)
                        while check_id < len(check_times) and check_times[check_id] <= now_time:
                            label_text_len += len(labels[check_id].answer)
                            if predict_text_len / self.context.char_contains_rate < label_text_len:
                                self.context.fail_char_contains_rate_num += 1
                            check_id += 1

                    # Pace the stream so each iteration takes about wait_time seconds.
                    await asyncio.sleep(max(0, self.context.wait_time - (time.perf_counter() - now_time)))

            await websocket.send(json.dumps({"end": True}))
            logger.info(f"第 {self.idx} 条数据,当条语音数据发送完成")
            logger.info(f"第 {self.idx} 条数据,3s 后关闭双向连接.")
            self.close_time = time.perf_counter() + 3
        except websockets.exceptions.ConnectionClosedError:
            logger.error(f"第 {self.idx} 条数据发送过程中,连接断开")
        except Exception:
            # format_exc() returns the traceback text. The former
            # print_exc() wrote it to stderr and returned None, so the
            # log line read "None".
            logger.error(traceback.format_exc())
            logger.error(f"第 {self.idx} 条数据,发送数据失败")

    async def _recv(
        self, websocket: websockets.WebSocketClientProtocol, send_queue: asyncio.Queue, recv_queue: asyncio.Queue
    ):
        """Check the init response, then collect recognition results until the connection closes.

        Queue roles (``_action`` passes the queues crossed):
          - ``recv_queue`` is the queue ``_sender`` writes to.
          - ``send_queue`` is the queue ``_sender`` reads from.

        Init outcome:
          - Any message that arrives within 2 s counts as a successful init;
            its content is not inspected.
          - On timeout or error, ``FIRST_FAIL`` is reported and the method returns.
          - The failure is counted by ``_action``, which sees that no audio was
            sent. It is not counted here, so each failed attempt counts once.

        A result that fails schema validation terminates the whole process
        with ``os._exit(1)``.
        """
        # Wait until _sender has sent the init message.
        await recv_queue.get()
        try:
            await asyncio.wait_for(websocket.recv(), timeout=2)
        except asyncio.TimeoutError:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.info(f"第 {self.idx} 条数据,初始化阶段, 2s 没收到 success 返回,超时了")
            return
        except Exception as e:
            await send_queue.put(STATUS_DATA.FIRST_FAIL)
            logger.error(f"第 {self.idx} 条数据,初始化阶段, 收到异常:{e}")
            return
        await send_queue.put(STATUS_DATA.SUCCESS)

        # Collect recognition results.
        try:
            while websocket.open:
                recv_data = await websocket.recv()
                if not isinstance(recv_data, str):
                    raise Exception("返回的结果不是字符串形式")
                self.recv_time.append(time.perf_counter())
                result = StreamResultModel(**json.loads(recv_data))
                recognition_results = result.asr_results
                # A final result with no language and all-zero timing and
                # para_seq is dropped (it appears to be an empty placeholder).
                is_empty_final = (
                    recognition_results.final_result
                    and not recognition_results.language
                    and recognition_results.start_time == 0
                    and recognition_results.end_time == 0
                    and recognition_results.para_seq == 0
                )
                if not is_empty_final:
                    self.predict_data.append(recognition_results)
        except (
            websockets.exceptions.ConnectionClosedOK,
            websockets.exceptions.ConnectionClosedError,
            OSError,
        ):
            # The connection ended; whatever was collected so far is kept.
            pass
        except ValidationError as e:
            logger.error(f"第 {self.idx} 条数据,返回的结果不符合格式")
            logger.error(f"Exception is {e}")
            # Deliberately hard-exits the process (no cleanup) on a malformed result.
            os._exit(1)
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(f"第 {self.idx} 条数据,处理被测服务返回数据时出错")

    async def _action(self):
        """Run one streaming session, retrying up to 3 attempts in total.

        An attempt counts as failed when no audio chunk was sent, e.g. because
        init did not succeed. Each attempt:
          - opens a new connection;
          - resets the collected timestamps and predictions;
          - waits 1 s before the next retry.

        Errors from ``websockets.connect`` are not caught here.
        """
        logger.info(f"第 {self.idx} 条数据开始测试")

        while self.fail_count < 3:
            # Named from _sender's point of view: _sender writes to send_queue
            # and reads from recv_queue.
            send_queue = asyncio.Queue()
            recv_queue = asyncio.Queue()

            self.send_time = []
            self.recv_time = []
            self.predict_data = []

            async with websockets.connect(self.base_url) as websocket:
                send_task = asyncio.create_task(self._sender(websocket, send_queue, recv_queue))
                # Crossed on purpose: _recv reads what _sender writes, and vice versa.
                recv_task = asyncio.create_task(self._recv(websocket, recv_queue, send_queue))

                await send_task
                # Grace period for the last results after the end marker.
                await asyncio.sleep(3)
                # NOTE(review): this waits until the server closes the
                # connection. Nothing here closes it after the grace period.
                await recv_task

            if self.send_time:
                break
            self.fail_count += 1
            logger.info(f"第 {self.idx} 条数据,初始化阶段, 第 {self.fail_count} 次失败, 1s 后重试")
            # Non-blocking wait; time.sleep would block the event loop.
            await asyncio.sleep(1)

    def action(self):
        """Run the streaming session synchronously and return the scored context."""
        asyncio.run(self._action())
        return self._gen_result()

    def _gen_result(self) -> ASRContext:
        """Record predictions and punctuation scores into ``self.context`` and return it.

        Steps:
          - Predictions and timestamps are passed to ``context.append_preds``.
          - ``context.fail`` is set when nothing was received.
          - ``context.punctuation_num`` is set to the number of labels.

        Punctuation scoring: ``context.pred_punctuation_num`` is incremented for
        each label that has a predicted punctuation word starting or ending
        inside the label's window. The window is:
          - ``[label.end, next_label.start]`` for every label but the last;
          - ``[end - 0.7, end + 0.7]`` for the last label.
        """
        from bisect import bisect_left, bisect_right

        if not self.predict_data:
            logger.error(f"第 {self.idx} 条数据,没有任何数据返回")
        self.context.append_preds(self.predict_data, self.send_time, self.recv_time)
        self.context.fail = not self.predict_data

        # Collect every predicted word that is a punctuation mark for its
        # result's language.
        punctuation_words: List[StreamWordsModel] = []
        for pred in self.predict_data:
            punctuations = self._PUNCTUATIONS_BY_LANGUAGE.get(pred.language, self._DEFAULT_PUNCTUATIONS)
            punctuation_words.extend(word for word in pred.words if word.text in punctuations)

        start_times = sorted(word.start_time for word in punctuation_words)
        end_times = sorted(word.end_time for word in punctuation_words)

        labels = self.context.labels
        self.context.punctuation_num = len(labels)
        for i, label in enumerate(labels):
            if i < len(labels) - 1:
                label_left, label_right = label.end, labels[i + 1].start
            else:
                label_left, label_right = label.end - 0.7, label.end + 0.7

            # Earliest punctuation start >= label_left: check that it is <= label_right.
            first_start = bisect_left(start_times, label_left)
            starts_inside = first_start < len(start_times) and start_times[first_start] <= label_right
            # Latest punctuation end <= label_right: check that it is >= label_left.
            last_end = bisect_right(end_times, label_right) - 1
            ends_inside = last_end >= 0 and end_times[last_end] >= label_left

            if starts_inside or ends_inside:
                self.context.pred_punctuation_num += 1
        return self.context

    def _gen_init_data(self) -> dict:
        """Build the init message's parameters from the context's audio settings (``lang`` is left ``None``)."""
        return {
            "parameter": {
                "lang": None,
                "sample_rate": self.context.sample_rate,
                "channel": self.context.channel,
                "format": self.context.audio_format,
                "bits": self.context.bits,
                "enable_words": self.context.enable_words,
            }
        }
|