196 lines
6.9 KiB
Python
196 lines
6.9 KiB
Python
import json, os, threading, time, traceback
|
||
from typing import (
|
||
Any, List
|
||
)
|
||
from copy import deepcopy
|
||
from utils.logger import logger
|
||
|
||
from websocket import (
|
||
create_connection,
|
||
WebSocketConnectionClosedException,
|
||
ABNF
|
||
)
|
||
|
||
from utils.model import (
|
||
ASRResponseModel,
|
||
SegmentModel
|
||
)
|
||
import threading
|
||
import queue
|
||
|
||
from pydantic_core import ValidationError
|
||
|
||
# True when the SUBMIT_CONFIG_FILEPATH environment variable is not set.
# NOTE(review): presumably this marks a local/test run as opposed to a
# submission run -- confirm against the code that reads this flag.
_IS_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
|
||
|
||
|
||
class ASRWebSocketClient:
    """WebSocket client that streams an audio file to an ASR endpoint.

    Typical use, as shown by this module's ``__main__`` section:
    ``initialize_connection(language)`` performs the handshake, then
    ``execute(path)`` streams the file and collects the recognition
    segments, and ``shutdown()`` closes the socket.
    """

    # Number of leading bytes skipped for paths ending in "wav".
    WAV_HEADER_BYTES = 44
    # Number of audio bytes sent per binary frame.
    # NOTE(review): presumably 100 ms of 16 kHz / 16-bit mono PCM -- confirm.
    CHUNK_BYTES = 3200
    # Pause between two consecutive audio frames, in seconds.
    CHUNK_INTERVAL_SECONDS = 0.1
    # The handshake gives up once ``conn_attempts`` exceeds this value.
    MAX_CONN_ATTEMPTS = 5

    def __init__(self, url: str):
        """Store the target endpoint; no network traffic happens here.

        Args:
            url: Base WebSocket URL; ``/recognition`` is appended to it.
        """
        self.endpoint = f"{url}/recognition"
        self.ws = None
        # Failure counter used by initialize_connection(); it starts at -1
        # and is reset to 0 after a successful handshake.
        self.conn_attempts = -1
        self.failed = False
        self.terminate_time = float("inf")
        self.sent_timestamps: List[float] = []
        self.received_timestamps: List[float] = []
        self.results: List[Any] = []
        # NOTE(review): starts as True although no socket is open yet; kept
        # unchanged because external callers may read this flag.
        self.connected = True
        logger.info(f"Target endpoint: {self.endpoint}")

    def execute(self, path: str) -> "List[SegmentModel]":
        """Stream ``path`` to the server and collect the recognised segments.

        Sending and receiving run on two threads so that the arrival time
        of every response can be measured while audio is still being sent.

        Args:
            path: Audio file to stream (see ``send``).

        Returns:
            The segments returned by ``receive``, or ``None`` when the
            receiving thread failed (the failure is logged).
        """
        # Carries the return value of receive() out of its thread.
        result_queue = queue.Queue()

        def send_thread_fn():
            try:
                self.send(path)
            except Exception:
                logger.error(f"发送线程异常退出:\n{traceback.format_exc()}")
                # Without this the receiver would block in recv() forever
                # waiting for a final message the server will never send.
                self.shutdown()

        def receive_thread_fn():
            try:
                result_queue.put(self.receive(start_time))
            except Exception:
                logger.error(f"接收线程异常退出:\n{traceback.format_exc()}")
                # Callers get None instead of a segment list on failure.
                result_queue.put(None)

        send_thread = threading.Thread(target=send_thread_fn)
        receive_thread = threading.Thread(target=receive_thread_fn)

        # Reference point for the receive timestamps.
        start_time = time.time()
        send_thread.start()
        receive_thread.start()

        send_thread.join()
        receive_thread.join()

        # The receiver always enqueues exactly one item before it exits.
        return result_queue.get()

    def initialize_connection(self, language: str):
        """Open the socket and perform the initial handshake.

        The init payload is sent after connecting, and responses are read
        until one is a JSON object with a truthy ``"success"`` field. A
        failed attempt (an unexpected error, a non-JSON reply or
        ``"success"`` being falsy) triggers a reconnect.

        Args:
            language: Language code placed in the init payload.

        Raises:
            ConnectionRefusedError: Too many failed attempts.
            RuntimeError: The deadline passed; it is ``end_time`` seconds
                (environment variable, default 2) after the call started.
            Exception: The connection was closed or timed out during the
                handshake.
        """
        expiration = time.time() + float(os.getenv("end_time", "2"))
        self.connected = False
        init = False
        while True:
            attempt_failed = False
            try:
                if not init:
                    if self.ws is not None:
                        # Release the socket of the previous attempt (or of
                        # an earlier session) instead of leaking it.
                        self.shutdown()
                        self.ws = None
                    logger.debug("建立ws链接:发送建立ws链接请求。")
                    self.ws = create_connection(self.endpoint)
                    body = json.dumps(self._get_init_payload(language))
                    logger.debug(f"建立ws链接:发送初始化数据。{body}")
                    self.ws.send(body)
                    init = True

                msg = self.ws.recv()
                logger.debug(f"收到响应数据: {msg}")
                if not msg:
                    # Nothing yet: give the server time to write its reply.
                    time.sleep(0.5)
                else:
                    if isinstance(msg, str):
                        try:
                            msg = json.loads(msg)
                        except ValueError:
                            logger.info("建立ws链接:响应数据非json格式!")
                            attempt_failed = True
                    if isinstance(msg, dict):
                        if msg.get("success"):
                            logger.debug("建立ws链接:链接建立成功!")
                            self.conn_attempts = 0
                            self.connected = True
                            return
                        attempt_failed = True
            except (WebSocketConnectionClosedException, TimeoutError) as err:
                # A tuple is required here: ``A or B`` evaluates to ``A``
                # and would never match TimeoutError.
                raise Exception("建立ws链接:初始化阶段连接中断,退出。") from err
            except Exception:
                attempt_failed = True

            # Bookkeeping lives outside the ``try`` so the errors raised
            # below are not swallowed by the generic handler above.
            if attempt_failed:
                logger.info("建立ws链接:链接建立失败!")
                init = False
                self.conn_attempts = self.conn_attempts + 1
                if self.conn_attempts > self.MAX_CONN_ATTEMPTS:
                    raise ConnectionRefusedError("重试5次后,仍然无法建立ws链接。")
            # Checked on every pass, so empty or unrecognised replies
            # cannot keep the loop alive past the deadline.
            if time.time() > expiration:
                raise RuntimeError("建立ws链接:链接建立超时!")

    def shutdown(self):
        """Close the socket (best effort) and mark the client disconnected."""
        try:
            if self.ws:
                self.ws.close()
        except Exception as err:
            # Closing is best effort; the error is only recorded.
            logger.debug(f"关闭ws链接时出现异常: {err}")
        self.connected = False

    def _get_init_payload(self, language="zh") -> dict:
        """Return the handshake message carrying the requested language."""
        return {"language": language}

    def _get_finish_payload(self) -> dict:
        """Return the message sent after the last audio frame."""
        return {"end": "true"}

    def send(self, path):
        """Stream the file at ``path`` as binary frames, then send the end marker.

        For paths ending in ``"wav"`` the first ``WAV_HEADER_BYTES`` bytes
        are skipped. The rest is sent in ``CHUNK_BYTES`` frames with a
        pause of ``CHUNK_INTERVAL_SECONDS`` after each one.
        """
        skip_wav_header = path.endswith("wav")
        with open(path, "rb") as f:
            if skip_wav_header:
                f.read(self.WAV_HEADER_BYTES)

            while chunk := f.read(self.CHUNK_BYTES):
                logger.debug(f"发送 {len(chunk)} 字节数据.")
                self.ws.send(chunk, opcode=ABNF.OPCODE_BINARY)
                time.sleep(self.CHUNK_INTERVAL_SECONDS)
        self.ws.send(json.dumps(self._get_finish_payload()))

    @staticmethod
    def _find_missing_seqs(sorted_seqs: List[int]) -> List[int]:
        """Return the numbers missing between neighbours of an ascending list."""
        missing = []
        for current, following in zip(sorted_seqs, sorted_seqs[1:]):
            if following != current + 1:
                missing.extend(range(current + 1, following))
        return missing

    def receive(self, start_time) -> "List[SegmentModel]":
        """Read responses until one arrives without ``asr_results``.

        Args:
            start_time: ``time.time()`` value used as the zero point for
                the receive timestamps.

        Returns:
            The received ``asr_results`` items sorted by ``para_seq``, each
            with ``receive_time`` set to the milliseconds elapsed since
            ``start_time`` when its message was read.
        """
        results = []
        while True:
            msg = self.ws.recv()
            # Time at which the data was read, in milliseconds.
            now = 1000 * (time.time() - start_time)
            logger.debug(f"{now} 收到响应数据: {msg}")
            res = json.loads(msg)
            if not res.get("asr_results"):
                logger.debug("响应结束")
                break
            item = ASRResponseModel.model_validate_json(msg).asr_results
            item.receive_time = now
            results.append(item)

        # Sort by para_seq and check that the numbering has no gaps.
        results.sort(key=lambda x: x.para_seq)
        missing_seqs = self._find_missing_seqs([r.para_seq for r in results])
        if missing_seqs:
            logger.warning(f"检测到丢失的段落序号:{missing_seqs}")
        else:
            logger.debug("响应数据正常")

        return results
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke run against a locally running recognition service.
    ws_client = ASRWebSocketClient("ws://localhost:18000")
    try:
        ws_client.initialize_connection("zh")

        res = ws_client.execute("/Users/yu/Documents/code-work/asr-live-iluvatar/zh_250312/zh/99.wav")
        # execute() returns None when the receiving thread failed, so
        # guard the iteration instead of crashing with a TypeError.
        for i in res or []:
            print(i.summary())
            print()
    finally:
        # Always release the socket, even when the handshake or run fails.
        ws_client.shutdown()
|