initial commit
This commit is contained in:
195
utils/client.py
Normal file
195
utils/client.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import json, os, threading, time, traceback
|
||||
from typing import (
|
||||
Any, List
|
||||
)
|
||||
from copy import deepcopy
|
||||
from utils.logger import logger
|
||||
|
||||
from websocket import (
|
||||
create_connection,
|
||||
WebSocketConnectionClosedException,
|
||||
ABNF
|
||||
)
|
||||
|
||||
from utils.model import (
|
||||
ASRResponseModel,
|
||||
SegmentModel
|
||||
)
|
||||
import threading
|
||||
import queue
|
||||
|
||||
from pydantic_core import ValidationError
|
||||
|
||||
_IS_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH") is None
|
||||
|
||||
|
||||
class ASRWebSocketClient:
    """Client that streams an audio file to a WebSocket ASR endpoint.

    Usage: construct with the server base URL, call
    ``initialize_connection`` to perform the handshake, then ``execute`` with
    the path of an audio file to stream it and collect the recognised
    segments. Call ``shutdown`` when done.
    """

    # A handshake is abandoned once ``conn_attempts`` exceeds this value.
    _MAX_CONN_ATTEMPTS = 5
    # Bytes skipped at the start of a ``*.wav`` file (canonical header size).
    _WAV_HEADER_BYTES = 44
    # Audio is sent in chunks of this many bytes, one chunk per interval.
    # NOTE(review): 3200 B / 0.1 s looks like real-time 16 kHz 16-bit mono
    # PCM -- confirm against the server contract.
    _CHUNK_BYTES = 3200
    _CHUNK_INTERVAL_S = 0.1

    def __init__(self, url: str):
        """Store the target endpoint; no network traffic happens here.

        Args:
            url: Base WebSocket URL; ``/recognition`` is appended to it.
        """
        self.endpoint = f"{url}/recognition"
        self.ws = None
        # Starts at -1 and is reset to 0 after a successful handshake.
        self.conn_attempts = -1
        self.failed = False
        self.terminate_time = float("inf")
        self.sent_timestamps: List[float] = []
        self.received_timestamps: List[float] = []
        self.results: List[Any] = []
        # NOTE(review): True although no socket exists yet; kept unchanged
        # because external callers may read this flag -- confirm intent.
        self.connected = True
        logger.info(f"Target endpoint: {self.endpoint}")

    def execute(self, path: str) -> "List[SegmentModel]":
        """Stream ``path`` to the server while collecting recognition results.

        One thread sends the audio and another receives responses, so that
        each response can be timestamped relative to the start of sending.

        Args:
            path: Path of the audio file to stream.

        Returns:
            The segments returned by ``receive``, or ``None`` if receiving
            raised an exception (the traceback is logged).
        """
        sender = threading.Thread(target=self.send, args=(path,))
        # Reference point for the receive timestamps.
        start_time = time.time()
        sender.start()

        # Carries the receiver's return value back to this thread.
        result_queue = queue.Queue()

        def receive_thread_fn():
            try:
                result_queue.put(self.receive(start_time))
            except Exception:
                # Surface the failure instead of dropping it silently; the
                # caller still gets None as before.
                logger.warning(traceback.format_exc())
                result_queue.put(None)

        receiver = threading.Thread(target=receive_thread_fn)
        receiver.start()

        sender.join()
        receiver.join()

        # The receiver always puts exactly one item, so this cannot block
        # once the thread has been joined.
        return result_queue.get()

    def initialize_connection(self, language: str):
        """Open the socket and perform the initial handshake, with retries.

        The handshake sends the init payload and waits for a JSON object
        whose ``success`` field is truthy. Failed attempts are retried until
        the attempt budget or the time budget (environment variable
        ``end_time`` in seconds, default 2) is exhausted.

        Args:
            language: Language code placed in the init payload.

        Raises:
            ConnectionRefusedError: Too many failed attempts.
            RuntimeError: The time budget elapsed after a failed attempt.
            Exception: The connection was closed or timed out mid-handshake.
        """
        expiration = time.time() + float(os.getenv("end_time", "2"))
        self.connected = False
        init = False
        while True:
            try:
                if not init:
                    # Release any socket left over from a previous attempt so
                    # that retries do not leak connections.
                    self._close_socket_quietly()
                    logger.debug("建立ws链接:发送建立ws链接请求。")
                    self.ws = create_connection(self.endpoint)
                    body = json.dumps(self._get_init_payload(language))
                    logger.debug(f"建立ws链接:发送初始化数据。{body}")
                    self.ws.send(body)
                    init = True

                msg = self.ws.recv()
                logger.debug(f"收到响应数据: {msg}")
                if len(msg) == 0:
                    # Nothing yet: give the server a moment, then read again.
                    time.sleep(0.5)
                    continue

                if isinstance(msg, str):
                    try:
                        msg = json.loads(msg)
                    except Exception:
                        # Caught by the generic handler below and counted as
                        # a failed attempt.
                        raise Exception("建立ws链接:响应数据非json格式!")

                if not isinstance(msg, dict):
                    # Not a handshake answer; keep reading on this socket.
                    continue

                if msg.get("success"):
                    logger.debug("建立ws链接:链接建立成功!")
                    self.conn_attempts = 0
                    self.connected = True
                    return
            except (WebSocketConnectionClosedException, TimeoutError) as err:
                # ``except A or B`` only ever matched A; a tuple is required
                # to catch both.
                raise Exception("建立ws链接:初始化阶段连接中断,退出。") from err
            except Exception:
                logger.debug(traceback.format_exc())

            # Reached when the server rejected the handshake or the attempt
            # raised: account for the failure once, then retry from scratch.
            logger.info("建立ws链接:链接建立失败!")
            init = False
            self._register_failed_attempt(expiration)

    def _register_failed_attempt(self, expiration: float) -> None:
        """Count one failed handshake attempt and raise if a budget is spent."""
        self.conn_attempts += 1
        if self.conn_attempts > self._MAX_CONN_ATTEMPTS:
            raise ConnectionRefusedError("重试5次后,仍然无法建立ws链接。")
        if time.time() > expiration:
            raise RuntimeError("建立ws链接:链接建立超时!")

    def _close_socket_quietly(self) -> None:
        """Close the current socket, if any; errors are logged, not raised."""
        try:
            if self.ws:
                self.ws.close()
        except Exception:
            # Best-effort cleanup: a failing close must not mask the caller's
            # own outcome.
            logger.debug(traceback.format_exc())

    def shutdown(self):
        """Close the socket (best effort) and mark the client disconnected."""
        self._close_socket_quietly()
        self.connected = False

    def _get_init_payload(self, language="zh") -> dict:
        """Return the handshake payload announcing the audio language."""
        return {"language": language}

    def _get_finish_payload(self) -> dict:
        """Return the payload that tells the server the audio has ended."""
        return {"end": "true"}

    def send(self, path):
        """Stream the file at ``path`` in binary chunks, then signal the end.

        Args:
            path: Audio file path. For names ending in ``wav`` the first
                44 bytes are skipped before streaming.
        """
        with open(path, "rb") as f:
            if path.endswith("wav"):
                # Skip the header; assumes the canonical 44-byte layout.
                f.read(self._WAV_HEADER_BYTES)

            while chunk := f.read(self._CHUNK_BYTES):
                logger.debug(f"发送 {len(chunk)} 字节数据.")
                self.ws.send(chunk, opcode=ABNF.OPCODE_BINARY)
                # Pace the upload rather than sending the file in one burst.
                time.sleep(self._CHUNK_INTERVAL_S)

        self.ws.send(json.dumps(self._get_finish_payload()))

    def receive(self, start_time) -> "List[SegmentModel]":
        """Read responses until one arrives without ``asr_results``.

        Each result is stamped with ``receive_time``: milliseconds elapsed
        since ``start_time`` when its message was read.

        Args:
            start_time: ``time.time()`` value taken when sending started.

        Returns:
            The received results sorted by ``para_seq``. Gaps in the sequence
            numbers are logged as a warning but do not fail the call.
        """
        results = []
        while True:
            msg = self.ws.recv()
            # Elapsed time in milliseconds at the moment the data was read.
            now = 1000 * (time.time() - start_time)
            logger.debug(f"{now} 收到响应数据: {msg}")
            res = json.loads(msg)
            if not res.get("asr_results"):
                # A message without results marks the end of the stream.
                logger.debug("响应结束")
                break
            item = ASRResponseModel.model_validate_json(msg).asr_results
            item.receive_time = now
            results.append(item)

        # Order by paragraph sequence and check that numbering is contiguous.
        results.sort(key=lambda x: x.para_seq)
        missing_seqs = self._find_missing_seqs([r.para_seq for r in results])
        if missing_seqs:
            logger.warning(f"检测到丢失的段落序号:{missing_seqs}")
        else:
            logger.debug("响应数据正常")
        return results

    @staticmethod
    def _find_missing_seqs(seqs: List[int]) -> List[int]:
        """Return the integers skipped between neighbours of sorted ``seqs``."""
        missing: List[int] = []
        for current, following in zip(seqs, seqs[1:]):
            # Empty when the neighbours are consecutive or equal.
            missing.extend(range(current + 1, following))
        return missing
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: connect to a local server, stream one sample file
    # and print a summary of every recognised segment.
    ws_client = ASRWebSocketClient("ws://localhost:18000")
    try:
        ws_client.initialize_connection("zh")

        res = ws_client.execute("/Users/yu/Documents/code-work/asr-live-iluvatar/zh_250312/zh/99.wav")
        # execute() returns None when receiving failed; print nothing then
        # instead of crashing on iteration.
        for i in res or []:
            print(i.summary())
            print()
    finally:
        # Always release the socket, even if the handshake or run failed.
        ws_client.shutdown()
|
||||
Reference in New Issue
Block a user