# Files
# enginex-bi_series-vc-cnn/utils/client_callback.py
# zhousha 55a67e817e update
# 2025-08-06 15:38:55 +08:00
#
# 410 lines
# 15 KiB
# Python
# Raw Permalink Blame History
#
# Viewer warning: this file contains ambiguous Unicode characters.
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
import threading
import time
from typing import Dict, List, Optional
import requests
from flask import Flask, abort, request
from pydantic import BaseModel, Field, ValidationError, field_validator
from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops, evaluate_punctuation
from .logger import log
# IP address of this pod, embedded in the callback / heartbeat URLs handed to
# the SUT (see ClientCallback.predict). Looked up eagerly at import time, so a
# missing environment variable fails fast with KeyError.
MY_POD_IP: str = os.environ["MY_POD_IP"]
class StopException(Exception):
    """Raised to abort the current ASR task / evaluation flow."""
class EvaluateResult(BaseModel):
    # Per-audio evaluation summary built by ClientCallback.evaluate.
    # (Comments rather than a class docstring: pydantic copies a docstring
    # into the model's JSON schema.)
    lang: str
    cer: float
    # Alignment histograms: time difference (ms) -> number of aligned items,
    # for the first and the last word of each sentence.
    align_start: Dict[int, int] = Field(description="句首字对齐时间差值(ms) -> 对齐数")
    align_end: Dict[int, int] = Field(description="句尾字对齐时间差值(ms) -> 对齐数")
    # Summed first-word / last-word distances, in seconds.
    first_word_distance_sum: float = Field(description="句首字距离总和(s)")
    last_word_distance_sum: float = Field(description="句尾字距离总和(s)")
    # Timing: real-time factor as computed in ClientCallback.evaluate, and the
    # delay until the first result arrived (s).
    rtf: float = Field(description="翻译速度")
    first_receive_delay: float = Field(description="首包接收延迟(s)")
    # Volume: number of audio files and number of labelled sentences.
    query_count: int = Field(description="音频数")
    voice_count: int = Field(description="句子数")
    # Punctuation counts, predicted vs. labelled, overall and sentence-level.
    pred_punctuation_num: int = Field(description="预测标点数")
    label_punctuation_num: int = Field(description="标注标点数")
    pred_sentence_punctuation_num: int = Field(description="预测句子标点数")
    # NOTE: "setence" is a typo baked into the public field name; keep it.
    label_setence_punctuation_num: int = Field(description="标注句子标点数")
    # Raw material: every received result and the reference annotation.
    preds: List[StreamDataModel] = Field(description="预测结果")
    label: QueryData = Field(description="标注结果")
class ResultModel(BaseModel):
    # Payload the SUT posts to /api/asr/batch-callback/<taskId>.
    taskId: str
    status: str
    message: str = Field("")
    recognition_results: Optional[StreamDataModel] = Field(None)

    @field_validator("recognition_results", mode="after")
    @classmethod
    def convert_to_seconds(cls, v: Optional[StreamDataModel], info):
        """Divide every timestamp of the result by 1000 (ms -> s), in place.

        ``info`` is the pydantic ``ValidationInfo``; it is not used.
        """
        if v is not None:
            v.end_time /= 1000
            v.start_time /= 1000
            for w in v.words:
                w.start_time /= 1000
                w.end_time /= 1000
        return v
class ClientCallback:
    """Run one ASR task at a time against a SUT and collect its callbacks.

    The constructor starts a Flask server in a daemon thread that receives
    recognition-result callbacks and heartbeats from the SUT. ``predict``
    submits an audio file and starts two watchdog threads (deadline and
    heartbeat); ``evaluate`` scores the collected results afterwards.
    """

    def __init__(self, sut_url: str, port: int):
        self.sut_url = sut_url  # base URL of the ASR service (SUT), e.g. http://asr-service:8080
        self.port = port  # local port on which this client receives callbacks

        # Build all shared state BEFORE the server thread starts, so a request
        # arriving immediately never sees a half-initialised object.
        self.mutex = threading.Lock()
        self.finished = threading.Event()
        self.product_avaiable = True
        # Bumped by reset(). Watchdog threads remember the value they were
        # started with and exit once it changes, so a watchdog left over from
        # an earlier task can never stop or judge the current one.
        self._generation = 0
        self.reset()

        # Route 1: /api/asr/batch-callback/<taskId> - recognition results
        #          posted by the SUT (asr_callback); taskId is a path parameter.
        # Route 2: /api/asr-runner/report - heartbeats posted by the SUT
        #          (heartbeat).
        self.app = Flask(__name__)
        self.app.add_url_rule(
            "/api/asr/batch-callback/<taskId>",
            view_func=self.asr_callback,
            methods=["POST"],
        )
        self.app.add_url_rule(
            "/api/asr-runner/report",
            view_func=self.heartbeat,
            methods=["POST"],
        )
        logging.getLogger("werkzeug").disabled = True
        threading.Thread(
            target=self.app.run, args=("0.0.0.0", port), daemon=True
        ).start()

    def reset(self):
        """Clear all per-task state and invalidate watchdogs of earlier tasks."""
        self._generation += 1
        self.begin_time = None
        self.end_time = None
        self.first_receive_time = None
        self.last_heartbeat_time = None
        self.app_on = False
        self.para_seq = 0
        self.finished.clear()
        self.error: Optional[str] = None
        self.last_recognition_result: Optional[StreamDataModel] = None
        self.recognition_results: List[StreamDataModel] = []

    def _task_over(self, generation: int) -> bool:
        """Return True once task *generation* finished or was superseded.

        Caller must hold ``self.mutex``.
        """
        return self.finished.is_set() or generation != self._generation

    def _fail(self, error: str) -> None:
        """Log *error*, keep it if it is the first one, and stop the task.

        Caller must hold ``self.mutex``.
        """
        log.error(error)
        if self.error is None:
            self.error = error
        self.stop()

    def asr_callback(self, taskId: str):
        """Handle a recognition-result callback posted by the SUT.

        Returns "ok" when the result is accepted or the task is FINISHED.
        Aborts with 400 when no task is running, and with 404 for a
        malformed payload, an unexpected status, a missing result, or an
        out-of-order result. Out-of-order results additionally record an
        error and stop the task.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)  # silent: None instead of raising on bad JSON
        if body is None:
            abort(404)
        try:
            # Validate the payload structure (timestamps become seconds here).
            result = ResultModel.model_validate(body)
        except ValidationError as e:
            log.error("asr_callback: 结果格式错误: %s", e)
            abort(404)

        if result.status == "FINISHED":
            with self.mutex:
                # Don't restart the clock if a watchdog already stopped the task.
                if self.app_on:
                    self.stop()
            return "ok"

        if result.status != "RUNNING":
            log.error(
                "asr_callback: 结果状态错误: %s, message: %s",
                result.status,
                result.message,
            )
            abort(404)

        recognition_result = result.recognition_results
        if recognition_result is None:
            log.error("asr_callback: 结果中没有recognition_results字段")
            abort(404)

        with self.mutex:
            if not self.app_on:
                log.error("asr_callback: 应用已结束")
                abort(400)
            if recognition_result.para_seq < self.para_seq:
                self._fail(
                    "asr_callback: 结果中para_seq小于上一次的: %d < %d"
                    % (recognition_result.para_seq, self.para_seq)
                )
                abort(404)
            if recognition_result.para_seq > self.para_seq + 1:
                self._fail(
                    "asr_callback: 结果中para_seq大于上一次的+1 "
                    "说明存在para_seq = %d没有final_result为True确认"
                    % (self.para_seq + 1,)
                )
                abort(404)
            last = self.last_recognition_result
            if last is not None and recognition_result.start_time < last.end_time:
                self._fail(
                    "asr_callback: 结果中start_time小于上一次的end_time: %s < %s"
                    % (recognition_result.start_time, last.end_time)
                )
                abort(404)
            self.recognition_results.append(recognition_result)
            if recognition_result.final_result is True:
                # Only a final result confirms its para_seq; the next result
                # may then move on to para_seq + 1.
                self.para_seq = recognition_result.para_seq
            if last is None:
                self.first_receive_time = time.time()
            self.last_recognition_result = recognition_result
        return "ok"

    def heartbeat(self):
        """Handle a heartbeat posted by the SUT.

        A RUNNING status refreshes ``last_heartbeat_time``; any other status
        is logged and still answered with "ok" (without refreshing). Aborts
        with 400 when no task is running and 404 for a non-object JSON body.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            abort(404)
        status = body.get("status")
        if status != "RUNNING":
            message = body.get("message", "")
            suffix = ", message: %s" % (message,) if message else ""
            log.error("heartbeat: 状态错误: %s%s", status, suffix)
            return "ok"
        with self.mutex:
            self.last_heartbeat_time = time.time()
        return "ok"

    def predict(
        self,
        language: Optional[str],
        audio_file: str,
        audio_duration: float,
        task_id: str,
    ):
        """Submit *audio_file* to the SUT and start the watchdog threads.

        Args:
            language: language hint forwarded to the SUT (may be None).
            audio_file: path of the audio file to upload.
            audio_duration: audio length in seconds; scales the deadlines.
            task_id: task identifier, also embedded in the callback URL.

        Raises:
            StopException: a previous task is still running, or /predict did
                not answer HTTP 200 with status "OK".
            Any exception from opening the file or from ``requests`` is
            re-raised. On every failure after the task was started, the task
            is stopped first so the next call is not blocked.
        """
        with self.mutex:
            if self.app_on:
                log.error("上一音频尚未完成处理,流程出现异常")
                raise StopException()
            self.reset()
            self.app_on = True
            generation = self._generation
        try:
            self._submit(language, audio_file, task_id)
        except Exception:
            # Release the slot; otherwise app_on stays True and every later
            # predict() is rejected as "previous audio unfinished".
            with self.mutex:
                if generation == self._generation:
                    self.stop()
            raise
        threading.Thread(
            target=self.dead_line_check,
            args=(audio_duration, generation),
            daemon=True,
        ).start()
        threading.Thread(
            target=self.heartbeat_check, args=(generation,), daemon=True
        ).start()

    def _submit(self, language: Optional[str], audio_file: str, task_id: str) -> None:
        """POST the audio to ``<sut_url>/predict``; raise StopException unless
        the SUT answers HTTP 200 with JSON status "OK"."""
        callback_url = "http://%s:%d/api/asr/batch-callback/%s" % (
            MY_POD_IP,
            self.port,
            task_id,
        )
        heartbeat_url = "http://%s:%d/api/asr-runner/report" % (MY_POD_IP, self.port)
        # Close the upload handle as soon as the request is done.
        with open(audio_file, "rb") as audio:
            resp = requests.post(
                self.sut_url + "/predict",
                data={
                    "language": language,
                    "taskId": task_id,
                    "progressCallbackUrl": callback_url,
                    "heartbeatUrl": heartbeat_url,
                },
                files={"file": (audio_file, audio)},
                timeout=60,
            )
        if resp.status_code != 200:
            log.error("/predict接口返回http code %s", resp.status_code)
            raise StopException()
        status = resp.json().get("status")
        if status != "OK":
            log.error("/predict接口返回非OK状态: %s", status)
            raise StopException()

    def dead_line_check(self, audio_duration: float, generation: Optional[int] = None):
        """Deadline watchdog for task *generation* (default: the current one).

        * 10 s after start at least one result must have arrived, else error.
        * Until ``ddl = begin + max(audio_duration / 6 + 10, 30)`` it checks
          that the recognised end time stays ahead of a "death line" that
          starts 30 s after begin and advances 5.4 s of audio per wall-clock
          second; if caught, ``product_avaiable`` is cleared.
        * Not finished by ``ddl``: ``product_avaiable`` is cleared.
        * Not finished by ``begin + max(audio_duration / 3 + 10, 30)``: error.

        Sleeps always happen without holding ``self.mutex`` so callbacks
        keep flowing.
        """
        if generation is None:
            generation = self._generation
        begin_time = time.time()
        self.begin_time = begin_time

        # First-packet check.
        self.sleep_to(begin_time + 10)
        with self.mutex:
            if self._task_over(generation):
                return
            if self.last_recognition_result is None:
                self._fail("首包延迟内未收到返回")
                return

        # Catch-up ("death line") checks, first one at begin + 30 s.
        next_checktime = begin_time + 30
        ddl = begin_time + max((audio_duration / 6) + 10, 30)
        while time.time() < ddl:
            self.sleep_to(next_checktime)
            with self.mutex:
                if self._task_over(generation):
                    return
                if self.last_recognition_result is None:
                    self._fail("检测追赶线过程中获取最后一次识别结果异常")
                    return
                last_end_time = self.last_recognition_result.end_time
                expect_end_time = (next_checktime - begin_time - 30) * 5.4
                caught_up = last_end_time < expect_end_time
                if caught_up:
                    log.warning(
                        "识别时间位置 %s 被死亡追赶线 %s 已追上,将置为产品不可用",
                        last_end_time,
                        expect_end_time,
                    )
                    self.product_avaiable = False
            if caught_up:
                self.sleep_to(ddl)
                break
            # Next check: the moment the line would reach the current position.
            next_checktime = min(last_end_time / 5.4 + begin_time + 30 + 1, ddl)

        with self.mutex:
            if self._task_over(generation):
                return
            log.warning("识别速度rtf低于1/6, 将置为产品不可用")
            self.product_avaiable = False

        # Hard deadline.
        self.sleep_to(begin_time + max((audio_duration / 3) + 10, 30))
        with self.mutex:
            if self._task_over(generation):
                return
            self._fail("处理时间超过ddl %s " % (ddl - begin_time))

    def heartbeat_check(self, generation: Optional[int] = None):
        """Every 5 s, stop task *generation* (default: the current one) with an
        error if no heartbeat was received for more than 30 s."""
        if generation is None:
            generation = self._generation
        with self.mutex:
            self.last_heartbeat_time = time.time()
        while True:
            with self.mutex:
                if self._task_over(generation):
                    return
                silence = time.time() - self.last_heartbeat_time
                if silence > 30:
                    self._fail("asr_runner 心跳超时 %s" % (silence,))
                    return
            time.sleep(5)

    def sleep_to(self, to: float):
        """Sleep until the absolute ``time.time()`` value *to* (no-op if past)."""
        seconds = to - time.time()
        if seconds <= 0:
            return
        time.sleep(seconds)

    def stop(self):
        """Mark the current task finished and record its end time.

        Callers in this class hold ``self.mutex`` when calling it.
        """
        self.end_time = time.time()
        self.finished.set()
        self.app_on = False

    def evaluate(self, query_data: QueryData):
        """Score the results collected for the last task against *query_data*.

        Returns an EvaluateResult. Raises StopException if the begin time,
        end time or first-result time was never recorded. Side effects:
        ``self.recognition_results`` is narrowed to final results only, and
        ``product_avaiable`` is cleared when ``align_start[300] / voice_count``
        is below 0.8.
        """
        log.info("开始评估")
        if (
            self.begin_time is None
            or self.end_time is None
            or self.first_receive_time is None
        ):
            if self.begin_time is None:
                log.error("评估流程异常 无开始时间")
            if self.end_time is None:
                log.error("评估流程异常 无结束时间")
            if self.first_receive_time is None:
                log.error("评估流程异常 无首次接收时间")
            raise StopException()
        # Processing time minus a fixed 10 s, per second of audio.
        rtf = max(self.end_time - self.begin_time - 10, 0) / query_data.duration
        first_receive_delay = max(self.first_receive_time - self.begin_time, 0)
        query_count = 1
        voice_count = len(query_data.voice)
        preds = self.recognition_results
        # Only final results are scored.
        self.recognition_results = [r for r in preds if r.final_result]
        (
            pred_punctuation_num,
            label_punctuation_num,
            pred_sentence_punctuation_num,
            label_setence_punctuation_num,
        ) = evaluate_punctuation(query_data, self.recognition_results)
        (
            cer,
            _,
            align_start,
            align_end,
            first_word_distance_sum,
            last_word_distance_sum,
        ) = evaluate_editops(query_data, self.recognition_results)
        # With no labelled sentences the alignment rate is undefined; skip the
        # check instead of raising ZeroDivisionError.
        if voice_count > 0:
            align_rate = align_start[300] / voice_count
            if align_rate < 0.8:
                log.warning(
                    "评估结果首字300ms对齐率 %s < 0.8, 将置为产品不可用",
                    align_rate,
                )
                self.product_avaiable = False
        return EvaluateResult(
            lang=query_data.lang,
            cer=cer,
            align_start=align_start,
            align_end=align_end,
            first_word_distance_sum=first_word_distance_sum,
            last_word_distance_sum=last_word_distance_sum,
            rtf=rtf,
            first_receive_delay=first_receive_delay,
            query_count=query_count,
            voice_count=voice_count,
            pred_punctuation_num=pred_punctuation_num,
            label_punctuation_num=label_punctuation_num,
            pred_sentence_punctuation_num=pred_sentence_punctuation_num,
            label_setence_punctuation_num=label_setence_punctuation_num,
            preds=preds,
            label=query_data,
        )