Files
enginex-bi_series-vc-cnn/utils/client_callback.py

410 lines
15 KiB
Python
Raw Permalink Normal View History

2025-08-06 15:38:55 +08:00
import logging
import os
import threading
import time
from typing import Dict, List, Optional
import requests
from flask import Flask, abort, request
from pydantic import BaseModel, Field, ValidationError, field_validator
from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops, evaluate_punctuation
from .logger import log
# IP address embedded in the callback/heartbeat URLs handed to the SUT by
# ClientCallback.predict. Read once at import time: if the MY_POD_IP
# environment variable is missing, importing this module raises KeyError.
MY_POD_IP = os.environ["MY_POD_IP"]
class StopException(Exception):
    """Raised when the evaluation flow cannot continue and must be aborted."""
# Evaluation summary for a single audio query, as assembled by
# ClientCallback.evaluate. Field descriptions are kept verbatim because they
# are part of the generated schema.
class EvaluateResult(BaseModel):
    # --- accuracy -----------------------------------------------------
    lang: str
    cer: float

    # --- word alignment: time-difference bucket (ms) -> aligned count ---
    align_start: Dict[int, int] = Field(description="句首字对齐时间差值(ms) -> 对齐数")
    align_end: Dict[int, int] = Field(description="句尾字对齐时间差值(ms) -> 对齐数")
    first_word_distance_sum: float = Field(description="句首字距离总和(s)")
    last_word_distance_sum: float = Field(description="句尾字距离总和(s)")

    # --- speed --------------------------------------------------------
    rtf: float = Field(description="翻译速度")
    first_receive_delay: float = Field(description="首包接收延迟(s)")

    # --- counts -------------------------------------------------------
    query_count: int = Field(description="音频数")
    voice_count: int = Field(description="句子数")

    # --- punctuation statistics ----------------------------------------
    pred_punctuation_num: int = Field(description="预测标点数")
    label_punctuation_num: int = Field(description="标注标点数")
    pred_sentence_punctuation_num: int = Field(description="预测句子标点数")
    # NOTE: "setence" is a typo, but the name is part of the public field
    # set and is used by callers, so it must not be renamed here.
    label_setence_punctuation_num: int = Field(description="标注句子标点数")

    # --- raw data -----------------------------------------------------
    preds: List[StreamDataModel] = Field(description="预测结果")
    label: QueryData = Field(description="标注结果")
# Payload posted by the SUT to the progress-callback endpoint.
#   taskId              -- task identifier echoed back by the SUT
#   status              -- the callback handler distinguishes "RUNNING",
#                          "FINISHED" and everything else (treated as error)
#   message             -- optional free-form text, logged on error statuses
#   recognition_results -- one streamed recognition result, if any
class ResultModel(BaseModel):
    taskId: str
    status: str
    message: str = Field("")
    recognition_results: Optional[StreamDataModel] = Field(None)

    @field_validator("recognition_results", mode="after")
    @classmethod
    def convert_to_seconds(cls, v: Optional[StreamDataModel], _info):
        """Divide every timestamp of the validated result by 1000, in place.

        Runs after field validation, so ``v`` is already a
        ``StreamDataModel`` (or ``None``). The sentence-level
        ``start_time``/``end_time`` and the ``start_time``/``end_time`` of
        each entry in ``v.words`` are rescaled; per the method name this
        converts milliseconds to seconds.

        ``_info`` is the validation-info object pydantic supplies to
        three-argument validators; it is not used.
        """
        if v is None:
            return v
        v.end_time = v.end_time / 1000
        v.start_time = v.start_time / 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v
class ClientCallback:
    """Submit one audio task at a time to an ASR service and collect its callbacks.

    The instance runs a small Flask server in a daemon thread with two routes:

    * ``POST /api/asr/batch-callback/<taskId>`` -> :meth:`asr_callback`,
      which receives streamed recognition results;
    * ``POST /api/asr-runner/report`` -> :meth:`heartbeat`, which receives
      liveness reports.

    :meth:`predict` uploads the audio and starts two watchdog threads
    (:meth:`dead_line_check`, :meth:`heartbeat_check`). A task ends when
    :meth:`stop` is called, which sets the ``finished`` event; ``error`` then
    holds the first failure reason, if any. ``product_avaiable`` is cleared
    when a speed or alignment threshold is missed and is never re-set by
    :meth:`reset`.

    All per-task state is guarded by ``mutex``.
    """

    def __init__(self, sut_url: str, port: int):
        """Create the callback server and start serving immediately.

        Args:
            sut_url: Base URL of the ASR service, e.g. ``http://asr-service:8080``.
            port: Local port this client listens on for callbacks.
        """
        self.sut_url = sut_url
        self.port = port

        # Shared state must exist before the server thread starts; otherwise
        # a request arriving right away would hit missing attributes.
        self.mutex = threading.Lock()
        self.finished = threading.Event()
        self.product_avaiable = True
        self.reset()

        # Build the Flask application and register both routes.
        self.app = Flask(__name__)
        self.app.add_url_rule(
            "/api/asr/batch-callback/<taskId>",
            view_func=self.asr_callback,
            methods=["POST"],
        )
        self.app.add_url_rule(
            "/api/asr-runner/report",
            view_func=self.heartbeat,
            methods=["POST"],
        )

        # Silence the per-request access log.
        logging.getLogger("werkzeug").disabled = True
        threading.Thread(
            target=self.app.run, args=("0.0.0.0", port), daemon=True
        ).start()

    def reset(self):
        """Clear all per-task state so a new task can start.

        ``product_avaiable`` is intentionally left untouched.
        """
        self.begin_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.first_receive_time: Optional[float] = None
        self.last_heartbeat_time: Optional[float] = None
        self.app_on = False
        self.para_seq = 0
        self.finished.clear()
        self.error: Optional[str] = None
        self.last_recognition_result: Optional[StreamDataModel] = None
        self.recognition_results: List[StreamDataModel] = []

    def _fail_locked(self, error: str):
        """Log ``error``, keep it if it is the first failure, and stop the task.

        The caller must hold ``self.mutex``.
        """
        log.error(error)
        if self.error is None:
            self.error = error
        self.stop()

    def _wait_until(self, deadline: float):
        """Block until ``deadline`` (epoch seconds) or until the task finishes.

        Waiting on the ``finished`` event instead of sleeping lets watchdog
        threads exit as soon as their task ends, so they cannot linger into
        the next task and act on its state.
        """
        remaining = deadline - time.time()
        if remaining > 0:
            self.finished.wait(remaining)

    def asr_callback(self, taskId: str):
        """Flask view: accept one recognition-result callback from the SUT.

        Responds 400 when no task is active and 404 for a malformed body, an
        unexpected status, or a protocol violation. Protocol violations
        (``para_seq`` going backwards or skipping ahead, or a ``start_time``
        earlier than the previous result's ``end_time``) also stop the task
        and record the error. ``taskId`` comes from the URL and is not
        checked against the active task.
        """
        if self.app_on is False:
            abort(400)

        # silent=True: a body that is not valid JSON yields None.
        body = request.get_json(silent=True)
        if body is None:
            abort(404)
        try:
            result = ResultModel.model_validate(body)
        except ValidationError as e:
            log.error("asr_callback: 结果格式错误: %s", e)
            abort(404)

        # Normal completion.
        if result.status == "FINISHED":
            with self.mutex:
                self.stop()
            return "ok"

        # Anything other than RUNNING is an error report.
        if result.status != "RUNNING":
            log.error(
                "asr_callback: 结果状态错误: %s, message: %s",
                result.status,
                result.message,
            )
            abort(404)

        recognition_result = result.recognition_results
        if recognition_result is None:
            log.error("asr_callback: 结果中没有recognition_results字段")
            abort(404)

        with self.mutex:
            # Re-check under the lock: the task may have stopped meanwhile.
            if not self.app_on:
                log.error("asr_callback: 应用已结束")
                abort(400)

            if recognition_result.para_seq < self.para_seq:
                self._fail_locked(
                    "asr_callback: 结果中para_seq小于上一次的: %d < %d"
                    % (recognition_result.para_seq, self.para_seq)
                )
                abort(404)

            if recognition_result.para_seq > self.para_seq + 1:
                self._fail_locked(
                    "asr_callback: 结果中para_seq大于上一次的+1 "
                    "说明存在para_seq = %d没有final_result为True确认"
                    % (self.para_seq + 1,)
                )
                abort(404)

            previous = self.last_recognition_result
            if (
                previous is not None
                and recognition_result.start_time < previous.end_time
            ):
                self._fail_locked(
                    "asr_callback: 结果中start_time小于上一次的end_time: %s < %s"
                    % (recognition_result.start_time, previous.end_time)
                )
                abort(404)

            self.recognition_results.append(recognition_result)
            # Only a final result confirms its paragraph sequence number.
            if recognition_result.final_result is True:
                self.para_seq = recognition_result.para_seq
            if previous is None:
                self.first_receive_time = time.time()
            self.last_recognition_result = recognition_result
        return "ok"

    def heartbeat(self):
        """Flask view: accept one liveness report from the SUT.

        A report with status ``RUNNING`` refreshes ``last_heartbeat_time``;
        any other status is logged and acknowledged without refreshing it.
        Responds 400 when no task is active and 404 when the body is not a
        JSON object.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)
        # Valid JSON that is not an object (e.g. a list) has no .get().
        if not isinstance(body, dict):
            abort(404)

        status = body.get("status")
        if status != "RUNNING":
            message = body.get("message", "")
            if message:
                message = ", message: %s" % (message,)
            log.error("heartbeat: 状态错误: %s%s", status, message)
            return "ok"

        with self.mutex:
            self.last_heartbeat_time = time.time()
        return "ok"

    def predict(
        self,
        language: Optional[str],
        audio_file: str,
        audio_duration: float,
        task_id: str,
    ):
        """Upload ``audio_file`` to ``<sut_url>/predict`` and start the watchdogs.

        Args:
            language: Language value forwarded to the SUT as-is.
            audio_file: Path of the audio file to upload.
            audio_duration: Duration passed to :meth:`dead_line_check`.
            task_id: Identifier sent to the SUT and embedded in the callback URL.

        Raises:
            StopException: If the previous task is still active, the SUT
                answers with a non-200 HTTP code, or the JSON ``status`` is
                not ``"OK"``.
        """
        with self.mutex:
            if self.app_on:
                log.error("上一音频尚未完成处理,流程出现异常")
                raise StopException()
            self.reset()
            self.app_on = True

        # Keep the file open only for the duration of the upload.
        with open(audio_file, "rb") as audio:
            resp = requests.post(
                self.sut_url + "/predict",
                data={
                    "language": language,
                    "taskId": task_id,
                    "progressCallbackUrl": "http://%s:%d/api/asr/batch-callback/%s"
                    % (MY_POD_IP, self.port, task_id),
                    "heartbeatUrl": "http://%s:%d/api/asr-runner/report"
                    % (MY_POD_IP, self.port),
                },
                files={"file": (audio_file, audio)},
                timeout=60,
            )

        if resp.status_code != 200:
            log.error("/predict接口返回http code %s", resp.status_code)
            raise StopException()
        status = resp.json().get("status")
        if status != "OK":
            log.error("/predict接口返回非OK状态: %s", status)
            raise StopException()

        # Watchdog threads for this task.
        threading.Thread(
            target=self.dead_line_check, args=(audio_duration,), daemon=True
        ).start()
        threading.Thread(target=self.heartbeat_check, daemon=True).start()

    def dead_line_check(self, audio_duration: float):
        """Watchdog: enforce first-result, progress and total-time limits.

        Timeline, relative to the moment this method starts (stored as
        ``begin_time``):

        1. At 10 s a first result must have arrived, else the task fails.
        2. From 30 s until ``max(audio_duration / 6 + 10, 30)`` the
           ``end_time`` of the latest result must stay ahead of a "chase
           line" of ``(elapsed - 30) * 5.4``; if it is overtaken,
           ``product_avaiable`` is cleared.
        3. If the task is still running at that deadline,
           ``product_avaiable`` is cleared.
        4. If it is still running at ``max(audio_duration / 3 + 10, 30)``,
           the task fails.
        """
        begin_time = time.time()
        self.begin_time = begin_time

        # Step 1: first-result timeout.
        self._wait_until(begin_time + 10)
        with self.mutex:
            if self.last_recognition_result is None:
                self._fail_locked("首包延迟内未收到返回")
                return

        # Step 2: chase-line checks, the first one at 30 s.
        next_checktime = begin_time + 30
        ddl = begin_time + max((audio_duration / 6) + 10, 30)
        while time.time() < ddl:
            self._wait_until(next_checktime)
            with self.mutex:
                if self.finished.is_set():
                    return
                if self.last_recognition_result is None:
                    self._fail_locked("检测追赶线过程中获取最后一次识别结果异常")
                    return
                last_end_time = self.last_recognition_result.end_time

            expect_end_time = (next_checktime - begin_time - 30) * 5.4
            if last_end_time < expect_end_time:
                log.warning(
                    "识别时间位置 %s 被死亡追赶线 %s 已追上,将置为产品不可用",
                    last_end_time,
                    expect_end_time,
                )
                self.product_avaiable = False
                self._wait_until(ddl)
                break
            # Next check: one second after the chase line would reach the
            # current recognised position, capped at the deadline.
            next_checktime = last_end_time / 5.4 + begin_time + 30 + 1
            next_checktime = min(next_checktime, ddl)

        # Step 3: still running at the soft deadline.
        with self.mutex:
            if self.finished.is_set():
                return
            log.warning("识别速度rtf低于1/6, 将置为产品不可用")
            self.product_avaiable = False

        # Step 4: hard deadline.
        self._wait_until(begin_time + max((audio_duration / 3) + 10, 30))
        with self.mutex:
            if self.finished.is_set():
                return
            self._fail_locked("处理时间超过ddl %s " % (ddl - begin_time))

    def heartbeat_check(self):
        """Watchdog: fail the task when no heartbeat arrives for over 30 s.

        Polls every 5 s and returns as soon as the task is finished.
        """
        with self.mutex:
            self.last_heartbeat_time = time.time()
        while True:
            with self.mutex:
                if self.finished.is_set():
                    return
                silence = time.time() - self.last_heartbeat_time
                if silence > 30:
                    self._fail_locked("asr_runner 心跳超时 %s" % (silence,))
                    return
            # Wakes early when the task finishes, so this thread does not
            # survive into the next task.
            self.finished.wait(5)

    def sleep_to(self, to: float):
        """Sleep until the epoch timestamp ``to``; return at once if it has passed."""
        seconds = to - time.time()
        if seconds <= 0:
            return
        time.sleep(seconds)

    def stop(self):
        """Mark the current task as ended: record ``end_time`` and set ``finished``."""
        self.end_time = time.time()
        self.finished.set()
        self.app_on = False

    def evaluate(self, query_data: QueryData):
        """Score the collected results against the labelled ``query_data``.

        Side effects: ``self.recognition_results`` is replaced by the subset
        with ``final_result`` set, and ``product_avaiable`` is cleared when
        fewer than 80 % of sentences fall in the ``align_start[300]`` bucket.

        Returns:
            An ``EvaluateResult`` whose ``preds`` holds every received result
            (final and non-final).

        Raises:
            StopException: If the start, end or first-receive timestamp is
                missing.
        """
        log.info("开始评估")
        if (
            self.begin_time is None
            or self.end_time is None
            or self.first_receive_time is None
        ):
            if self.begin_time is None:
                log.error("评估流程异常 无开始时间")
            if self.end_time is None:
                log.error("评估流程异常 无结束时间")
            if self.first_receive_time is None:
                log.error("评估流程异常 无首次接收时间")
            raise StopException()

        # NOTE(review): a zero query_data.duration or an empty
        # query_data.voice raises ZeroDivisionError below -- confirm the
        # dataset guarantees both are non-zero.
        rtf = max(self.end_time - self.begin_time - 10, 0) / query_data.duration
        first_receive_delay = max(self.first_receive_time - self.begin_time, 0)
        query_count = 1
        voice_count = len(query_data.voice)

        # Keep the full stream for the report; score only final results.
        preds = self.recognition_results
        self.recognition_results = [r for r in preds if r.final_result]

        (
            pred_punctuation_num,
            label_punctuation_num,
            pred_sentence_punctuation_num,
            label_setence_punctuation_num,
        ) = evaluate_punctuation(query_data, self.recognition_results)
        (
            cer,
            _,
            align_start,
            align_end,
            first_word_distance_sum,
            last_word_distance_sum,
        ) = evaluate_editops(query_data, self.recognition_results)

        start_align_rate = align_start[300] / voice_count
        if start_align_rate < 0.8:
            log.warning(
                "评估结果首字300ms对齐率 %s < 0.8, 将置为产品不可用",
                start_align_rate,
            )
            self.product_avaiable = False

        return EvaluateResult(
            lang=query_data.lang,
            cer=cer,
            align_start=align_start,
            align_end=align_end,
            first_word_distance_sum=first_word_distance_sum,
            last_word_distance_sum=last_word_distance_sum,
            rtf=rtf,
            first_receive_delay=first_receive_delay,
            query_count=query_count,
            voice_count=voice_count,
            pred_punctuation_num=pred_punctuation_num,
            label_punctuation_num=label_punctuation_num,
            pred_sentence_punctuation_num=pred_sentence_punctuation_num,
            label_setence_punctuation_num=label_setence_punctuation_num,
            preds=preds,
            label=query_data,
        )