Files
enginex-c_series-asr/utils/client.py
aceforeverd a4ec58a45e init
2025-08-28 18:46:56 +08:00

196 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json, os, threading, time, traceback
from typing import (
Any, List
)
from copy import deepcopy
from utils.logger import logger
from websocket import (
create_connection,
WebSocketConnectionClosedException,
ABNF
)
from utils.model import (
ASRResponseModel,
SegmentModel
)
import threading
import queue
from pydantic_core import ValidationError
# True when SUBMIT_CONFIG_FILEPATH is absent from the environment.
# Presumably this distinguishes a local/test run from a submission run — confirm.
# NOTE(review): not referenced by the code visible in this module.
_IS_TEST = "SUBMIT_CONFIG_FILEPATH" not in os.environ
class ASRWebSocketClient:
    """Streaming client for the ASR websocket service at ``{url}/recognition``.

    Typical use: ``initialize_connection(language)`` to open the socket and
    complete the handshake, ``execute(path)`` to stream an audio file and
    collect the recognition results, then ``shutdown()``.
    """

    def __init__(self, url: str):
        """Prepare a client for *url*; no connection is opened here.

        :param url: base websocket URL, e.g. ``ws://localhost:18000``;
            ``/recognition`` is appended to form the endpoint.
        """
        self.endpoint = f"{url}/recognition"
        self.ws = None
        # -1 until initialize_connection() runs; afterwards the number of
        # failed handshake attempts in the latest call (0 after success).
        self.conn_attempts = -1
        # The attributes below are not read anywhere in this class; presumably
        # used by callers. NOTE(review): confirm they are still needed.
        self.failed = False
        self.terminate_time = float("inf")
        self.sent_timestamps: List[float] = []
        self.received_timestamps: List[float] = []
        self.results: List[Any] = []
        # NOTE(review): True before any connection exists. initialize_connection()
        # and shutdown() keep it accurate afterwards; left unchanged in case
        # callers rely on the initial value.
        self.connected = True
        logger.info(f"Target endpoint: {self.endpoint}")

    def execute(self, path: str) -> "List[SegmentModel]":
        """Stream *path* to the service and collect the recognition results.

        Audio is sent on one thread while responses are read on another, so
        each response's arrival time (relative to the start of sending) can be
        recorded while the upload is still in progress. Requires a connection
        opened by ``initialize_connection``.

        :param path: audio file to stream; see ``send``.
        :return: the result items collected by ``receive``, sorted by
            ``para_seq``; ``None`` if receiving failed (the error is logged).
        """
        # Carries the receiver's return value (or None on failure) back here.
        result_queue = queue.Queue()

        def send_thread_fn():
            try:
                self.send(path)
            except Exception:
                logger.exception("发送音频数据失败")
                # The finish payload was never sent, so the receiver would
                # block in recv() forever; abort the socket to wake it up.
                self._abort_connection()

        def receive_thread_fn():
            try:
                result_queue.put(self.receive(start_time))
            except Exception:
                logger.exception("接收识别结果失败")
                result_queue.put(None)

        send_thread = threading.Thread(target=send_thread_fn)
        receive_thread = threading.Thread(target=receive_thread_fn)
        # Reference point for the receive timestamps, taken just before sending starts.
        start_time = time.time()
        send_thread.start()
        receive_thread.start()
        send_thread.join()
        receive_thread.join()
        return result_queue.get()

    def initialize_connection(self, language: str):
        """Open the websocket and complete the language handshake.

        Sends the init payload and waits for a JSON object reply. A truthy
        ``success`` field completes the handshake. A falsy one, or any error
        other than those that abort immediately (below), counts as a failed
        attempt and is retried on a fresh connection.

        :param language: language code placed in the init payload.
        :raises ConnectionRefusedError: on the 6th failed attempt in this call
            (the initial try plus 5 retries).
        :raises RuntimeError: when an attempt fails after the time budget has
            elapsed; the budget is the ``end_time`` environment variable in
            seconds (default ``"2"``).
        :raises Exception: when the connection is closed or times out during
            the handshake (``WebSocketConnectionClosedException`` /
            ``TimeoutError``).
        """
        expiration = time.time() + float(os.getenv("end_time", "2"))
        self.connected = False
        # Each call gets a fresh retry budget; a counter left over from an
        # earlier failed call would otherwise make this call give up at once.
        self.conn_attempts = 0
        init = False
        while True:
            try:
                if not init:
                    logger.debug("建立ws链接发送建立ws链接请求。")
                    # Drop the socket of a previous or rejected attempt instead of leaking it.
                    self._release_socket()
                    self.ws = create_connection(self.endpoint)
                    body = json.dumps(self._get_init_payload(language))
                    logger.debug(f"建立ws链接发送初始化数据。{body}")
                    self.ws.send(body)
                    init = True
                msg = self.ws.recv()
                logger.debug(f"收到响应数据: {msg}")
                if len(msg) == 0:
                    time.sleep(0.5)  # give the server a moment to write its reply
                    continue
                if isinstance(msg, str):
                    try:
                        msg = json.loads(msg)
                    except Exception as err:
                        raise Exception("建立ws链接响应数据非json格式") from err
                if not isinstance(msg, dict):
                    # Not a handshake reply (e.g. a binary frame): keep waiting
                    # on the same connection.
                    continue
                if msg.get("success"):
                    logger.debug("建立ws链接链接建立成功")
                    self.conn_attempts = 0
                    self.connected = True
                    return
                # Explicit rejection: fall through to the failure handling below.
            except (WebSocketConnectionClosedException, TimeoutError) as e:
                # A tuple is required here: `except A or B` only ever catches A.
                raise Exception("建立ws链接初始化阶段连接中断退出。") from e
            except Exception as e:
                logger.debug(f"建立ws链接异常: {e!r}")
            # Failure handling sits outside the try so that the errors it raises
            # propagate instead of being caught and counted a second time.
            logger.info("建立ws链接链接建立失败")
            init = False
            self._record_failed_attempt(expiration)

    def _record_failed_attempt(self, expiration: float) -> None:
        """Count one failed handshake attempt; raise once the retry or time budget is spent."""
        self.conn_attempts += 1
        if self.conn_attempts > 5:
            raise ConnectionRefusedError("重试5次后仍然无法建立ws链接。")
        if time.time() > expiration:
            raise RuntimeError("建立ws链接链接建立超时")

    def _release_socket(self) -> None:
        """Close the current socket at once, without a websocket close handshake."""
        if self.ws is None:
            return
        try:
            self.ws.shutdown()
        except Exception as e:
            logger.debug(f"关闭ws连接失败: {e!r}")

    def _abort_connection(self) -> None:
        """Shut the socket down so a thread blocked in ``recv`` wakes up with an error."""
        if self.ws is None:
            return
        try:
            self.ws.abort()
        except Exception as e:
            logger.debug(f"中断ws连接失败: {e!r}")

    def shutdown(self):
        """Close the websocket (best effort) and mark the client disconnected."""
        try:
            if self.ws:
                self.ws.close()
        except Exception as e:
            # Best effort: the connection may already be broken.
            logger.debug(f"关闭ws连接失败: {e!r}")
        self.connected = False

    def _get_init_payload(self, language="zh") -> dict:
        """Return the handshake payload selecting *language*."""
        return {
            "language": language
        }

    def _get_finish_payload(self) -> dict:
        """Return the payload that tells the server no more audio follows."""
        return {
            "end": "true"
        }

    def send(self, path, chunk_size: int = 3200, interval: float = 0.1):
        """Stream the audio file at *path*, then send the finish payload.

        The file is sent as binary frames of at most *chunk_size* bytes, with a
        pause of *interval* seconds after each frame to pace the upload.
        NOTE(review): the defaults give 32000 bytes/s, presumably real time for
        16 kHz 16-bit mono audio — confirm against the server.

        For paths ending in ``wav`` (case-insensitive) the first 44 bytes are
        skipped, assuming a canonical 44-byte WAV header. For a WAV file with
        extra header chunks, those extra bytes would be sent as audio.

        :param path: path of the audio file (str or path-like).
        :param chunk_size: maximum bytes per binary frame (default 3200).
        :param interval: seconds to sleep after each frame (default 0.1).
        """
        skip_wav_header = os.fspath(path).lower().endswith("wav")
        with open(path, "rb") as f:
            if skip_wav_header:
                f.read(44)
            while chunk := f.read(chunk_size):
                logger.debug(f"发送 {len(chunk)} 字节数据.")
                self.ws.send(chunk, opcode=ABNF.OPCODE_BINARY)
                time.sleep(interval)
        self.ws.send(json.dumps(self._get_finish_payload()))

    def receive(self, start_time) -> "List[SegmentModel]":
        """Collect recognition results until the server signals the end.

        Each JSON message with a non-empty ``asr_results`` field is parsed into
        a result item stamped with ``receive_time`` (milliseconds since
        *start_time*). The first message without it ends the stream.

        :param start_time: ``time.time()`` value that receive times are relative to.
        :return: the collected items sorted by ``para_seq``. Gaps in the
            sequence are logged as a warning.
        """
        results = []
        while True:
            msg = self.ws.recv()
            # Arrival time of this message in milliseconds.
            now = 1000 * (time.time() - start_time)
            logger.debug(f"{now} 收到响应数据: {msg}")
            if not json.loads(msg).get("asr_results"):
                break
            item = ASRResponseModel.model_validate_json(msg).asr_results
            item.receive_time = now
            results.append(item)
        logger.debug("响应结束")
        # Sort by para_seq and check that the sequence numbers are contiguous.
        results.sort(key=lambda x: x.para_seq)
        missing_seqs = self._find_missing_seqs(results)
        if missing_seqs:
            logger.warning(f"检测到丢失的段落序号:{missing_seqs}")
        else:
            logger.debug("响应数据正常")
        return results

    @staticmethod
    def _find_missing_seqs(segments: "List[SegmentModel]") -> List[int]:
        """Return the ``para_seq`` values missing between neighbours of a sorted list."""
        missing: List[int] = []
        for prev, curr in zip(segments, segments[1:]):
            # Empty range when curr follows prev directly or repeats it.
            missing.extend(range(prev.para_seq + 1, curr.para_seq))
        return missing
if __name__ == '__main__':
    # Manual smoke test against a locally running ASR service.
    # Usage: pass an audio file path as the first argument (optional).
    import sys

    default_path = "/Users/yu/Documents/code-work/asr-live-iluvatar/zh_250312/zh/99.wav"
    audio_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    ws_client = ASRWebSocketClient("ws://localhost:18000")
    try:
        ws_client.initialize_connection("zh")
        res = ws_client.execute(audio_path)
    finally:
        # Always release the socket, even if the handshake or streaming fails.
        ws_client.shutdown()
    if res is None:
        # execute() returns None when receiving failed; the cause is logged.
        sys.exit("recognition failed; see the log for details")
    for i in res:
        print(i.summary())
        print()