410 lines
15 KiB
Python
410 lines
15 KiB
Python
import logging
|
||
import os
|
||
import threading
|
||
import time
|
||
from typing import Dict, List, Optional
|
||
|
||
import requests
|
||
from flask import Flask, abort, request
|
||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||
|
||
from schemas.dataset import QueryData
|
||
from schemas.stream import StreamDataModel
|
||
from utils.evaluator_plus import evaluate_editops, evaluate_punctuation
|
||
|
||
from .logger import log
|
||
|
||
# IP of the pod this client runs in, read once at import time (raises KeyError
# if the MY_POD_IP environment variable is unset). It is embedded in the
# callback and heartbeat URLs that ClientCallback.predict sends to the ASR service.
MY_POD_IP: str = os.environ["MY_POD_IP"]
|
||
|
||
|
||
class StopException(Exception):
    """Abort the current prediction/evaluation flow.

    Raised by ClientCallback after the reason has already been logged, so it
    carries no message of its own.
    """
|
||
|
||
|
||
class EvaluateResult(BaseModel):
    # Metrics for one evaluated audio file, as assembled by ClientCallback.evaluate.
    # Field names (including the misspelled label_setence_punctuation_num) are
    # keys of the serialized model, so they must not be renamed.

    # Language of the labelled query.
    lang: str
    # Character error rate as returned by evaluate_editops.
    cer: float

    # First-word alignment time difference (ms) -> alignment count.
    align_start: Dict[int, int] = Field(description="句首字对齐时间差值(ms) -> 对齐数")
    # Last-word alignment time difference (ms) -> alignment count.
    align_end: Dict[int, int] = Field(description="句尾字对齐时间差值(ms) -> 对齐数")

    # Sum of first-word distances, in seconds.
    first_word_distance_sum: float = Field(description="句首字距离总和(s)")
    # Sum of last-word distances, in seconds.
    last_word_distance_sum: float = Field(description="句尾字距离总和(s)")

    # Real-time factor. NOTE(review): the description literally reads
    # "translation speed"; it is kept verbatim because it is runtime metadata.
    rtf: float = Field(description="翻译速度")
    # Delay until the first result was received, in seconds.
    first_receive_delay: float = Field(description="首包接收延迟(s)")

    # Number of audio files covered by this result.
    query_count: int = Field(description="音频数")
    # Number of labelled sentences.
    voice_count: int = Field(description="句子数")

    # Punctuation counts: predicted vs. labelled, overall and sentence-level.
    pred_punctuation_num: int = Field(description="预测标点数")
    label_punctuation_num: int = Field(description="标注标点数")
    pred_sentence_punctuation_num: int = Field(description="预测句子标点数")
    label_setence_punctuation_num: int = Field(description="标注句子标点数")

    # Raw predictions and the ground-truth label they were scored against.
    preds: List[StreamDataModel] = Field(description="预测结果")
    label: QueryData = Field(description="标注结果")
|
||
|
||
|
||
class ResultModel(BaseModel):
    # Payload the ASR service posts to /api/asr/batch-callback/<taskId>;
    # ClientCallback.asr_callback validates the request body against it.

    taskId: str
    status: str
    message: str = Field("")
    recognition_results: Optional[StreamDataModel] = Field(None)

    @field_validator("recognition_results", mode="after")
    @classmethod
    def convert_to_seconds(cls, v: Optional[StreamDataModel], info):
        """Convert sentence and word timestamps from milliseconds to seconds, in place.

        *info* is the validation context pydantic passes positionally; it is unused.
        """
        if v is not None:
            v.start_time = v.start_time / 1000
            v.end_time = v.end_time / 1000
            for w in v.words:
                w.start_time = w.start_time / 1000
                w.end_time = w.end_time / 1000
        return v
|
||
|
||
|
||
class ClientCallback:
    """Drive one ASR evaluation task at a time against a service under test (SUT).

    A Flask app on a daemon thread receives the SUT's recognition-result
    callbacks and heartbeats. ``predict`` submits an audio file and starts two
    watchdog threads (deadline and heartbeat). ``evaluate`` scores the results
    collected for the last task.
    """

    def __init__(self, sut_url: str, port: int):
        """Initialise per-task state and start the callback HTTP server.

        :param sut_url: base URL of the ASR service, e.g. ``http://asr-service:8080``.
        :param port: local port on which this client listens for the SUT's callbacks.
        """
        self.sut_url = sut_url
        self.port = port

        self.mutex = threading.Lock()
        self.finished = threading.Event()
        self.product_avaiable = True
        # Bumped by reset(). Watchdog threads compare against it, so a thread
        # left over from a previous task never acts on the next task's state.
        self._task_generation = 0
        # Everything the request handlers read must exist before the server
        # starts; otherwise an early request fails with AttributeError.
        self.reset()

        # Routes:
        #   POST /api/asr/batch-callback/<taskId> -> asr_callback
        #        (recognition results; taskId is a path parameter naming the task)
        #   POST /api/asr-runner/report           -> heartbeat (SUT heartbeats)
        self.app = Flask(__name__)
        self.app.add_url_rule(
            "/api/asr/batch-callback/<taskId>",
            view_func=self.asr_callback,
            methods=["POST"],
        )
        self.app.add_url_rule(
            "/api/asr-runner/report",
            view_func=self.heartbeat,
            methods=["POST"],
        )

        logging.getLogger("werkzeug").disabled = True
        threading.Thread(
            target=self.app.run, args=("0.0.0.0", port), daemon=True
        ).start()

    def reset(self):
        """Clear all per-task state and invalidate the previous task's watchdogs."""
        self.begin_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.first_receive_time: Optional[float] = None
        self.last_heartbeat_time: Optional[float] = None
        self.app_on = False
        self.para_seq = 0
        self.finished.clear()
        self.error: Optional[str] = None
        self.last_recognition_result: Optional[StreamDataModel] = None
        self.recognition_results: List[StreamDataModel] = []
        self._task_generation += 1

    def asr_callback(self, taskId: str):
        """Handle one recognition-result callback from the SUT.

        Returns ``"ok"`` when the callback is accepted. Aborts with 400 when no
        task is running. Aborts with 404 on a malformed payload, an unexpected
        status, or an out-of-order result; ordering errors also record
        ``self.error`` and stop the task.

        :param taskId: task id from the URL path (not compared with the running task).
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)  # None when the body is not valid JSON
        if body is None:
            abort(404)
        try:
            # Validation also converts the payload's timestamps from ms to s.
            result = ResultModel.model_validate(body)
        except ValidationError as e:
            log.error("asr_callback: 结果格式错误: %s", e)
            abort(404)

        # The SUT reports that the whole task is done.
        if result.status == "FINISHED":
            with self.mutex:
                # A duplicate or late FINISHED must not overwrite end_time.
                if self.app_on:
                    self.stop()
            return "ok"

        if result.status != "RUNNING":
            log.error(
                "asr_callback: 结果状态错误: %s, message: %s",
                result.status,
                result.message,
            )
            abort(404)

        recognition_result = result.recognition_results
        if recognition_result is None:
            log.error("asr_callback: 结果中没有recognition_results字段")
            abort(404)

        with self.mutex:
            if not self.app_on:
                log.error("asr_callback: 应用已结束")
                abort(400)

            # self.para_seq is the last paragraph confirmed with final_result,
            # so only that paragraph or the next one may arrive now.
            if recognition_result.para_seq < self.para_seq:
                self._fail_locked(
                    "asr_callback: 结果中para_seq小于上一次的: %d < %d"
                    % (recognition_result.para_seq, self.para_seq)
                )
                abort(404)
            if recognition_result.para_seq > self.para_seq + 1:
                self._fail_locked(
                    "asr_callback: 结果中para_seq大于上一次的+1 "
                    "说明存在para_seq = %d没有final_result为True确认"
                    % (self.para_seq + 1,)
                )
                abort(404)
            previous = self.last_recognition_result
            if previous is not None and recognition_result.start_time < previous.end_time:
                self._fail_locked(
                    "asr_callback: 结果中start_time小于上一次的end_time: %s < %s"
                    % (recognition_result.start_time, previous.end_time)
                )
                abort(404)

            self.recognition_results.append(recognition_result)
            # NOTE(review): the nesting below was reconstructed. Only final
            # results advance para_seq / last_recognition_result; updating the
            # latter on partials would trip the start_time check above for every
            # sentence with more than one partial result. Confirm against VCS.
            if recognition_result.final_result is True:
                self.para_seq = recognition_result.para_seq
                if self.last_recognition_result is None:
                    self.first_receive_time = time.time()
                self.last_recognition_result = recognition_result

        return "ok"

    def heartbeat(self):
        """Handle a heartbeat from the SUT.

        A ``"RUNNING"`` status refreshes ``last_heartbeat_time``; any other
        status is only logged. Aborts with 400 when no task is running and with
        404 when the body is not a JSON object.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            abort(404)
        status = body.get("status")
        if status != "RUNNING":
            message = body.get("message", "")
            if message:
                message = ", message: %s" % (message,)
            log.error("heartbeat: 状态错误: %s%s", status, message)
            return "ok"

        with self.mutex:
            self.last_heartbeat_time = time.time()
        return "ok"

    def predict(
        self,
        language: Optional[str],
        audio_file: str,
        audio_duration: float,
        task_id: str,
    ):
        """Submit *audio_file* to the SUT's ``/predict`` endpoint and start the watchdogs.

        :param language: language sent to the SUT (may be None).
        :param audio_file: path of the audio file to upload.
        :param audio_duration: audio length in seconds, used to size the deadlines.
        :param task_id: task id, also embedded in the callback URL.
        :raises StopException: if a task is still running, or ``/predict`` does
            not answer HTTP 200 with ``{"status": "OK"}``. If the submission
            fails for any reason, the task is stopped before the exception
            propagates, so a later ``predict`` is not blocked.
        """
        with self.mutex:
            if self.app_on:
                log.error("上一音频尚未完成处理,流程出现异常")
                raise StopException()
            self.reset()
            self.app_on = True
            generation = self._task_generation

        try:
            self._submit(language, audio_file, task_id)
        except Exception:
            with self.mutex:
                if generation == self._task_generation and self.app_on:
                    self.stop()
            raise

        threading.Thread(
            target=self.dead_line_check,
            args=(audio_duration, generation),
            daemon=True,
        ).start()
        threading.Thread(
            target=self.heartbeat_check, args=(generation,), daemon=True
        ).start()

    def _submit(self, language: Optional[str], audio_file: str, task_id: str):
        """POST the audio to ``/predict``; raise StopException unless the SUT accepts it."""
        with open(audio_file, "rb") as audio:
            resp = requests.post(
                self.sut_url + "/predict",
                data={
                    "language": language,
                    "taskId": task_id,
                    "progressCallbackUrl": "http://%s:%d/api/asr/batch-callback/%s"
                    % (MY_POD_IP, self.port, task_id),
                    "heartbeatUrl": "http://%s:%d/api/asr-runner/report"
                    % (MY_POD_IP, self.port),
                },
                files={"file": (audio_file, audio)},
                timeout=60,
            )

        if resp.status_code != 200:
            log.error("/predict接口返回http code %s", resp.status_code)
            raise StopException()

        status = resp.json().get("status")
        if status != "OK":
            log.error("/predict接口返回非OK状态: %s", status)
            raise StopException()

    def dead_line_check(self, audio_duration: float, generation: Optional[int] = None):
        """Enforce the first-packet, catch-up-line and overall deadlines of a task.

        Sets ``begin_time`` and then:
          * at begin+10 s, fails the task if no result has been recorded yet;
          * from begin+30 s until ``ddl = begin + max(audio_duration/6 + 10, 30)``,
            checks that the last result's ``end_time`` (seconds) stays ahead of a
            line advancing 5.4 s of audio per wall-clock second. If the line
            catches up, ``product_avaiable`` is set to False;
          * if the task is not finished at ``ddl``, sets ``product_avaiable`` to
            False, and fails the task if it is still not finished at
            ``begin + max(audio_duration/3 + 10, 30)``.

        :param audio_duration: audio length in seconds.
        :param generation: task generation to watch; defaults to the current one.
            The thread exits as soon as the task finishes or is superseded.
        """
        if generation is None:
            generation = self._task_generation
        begin_time = time.time()
        with self.mutex:
            if generation != self._task_generation:
                return
            self.begin_time = begin_time

        # First-packet check after 10 s.
        self.sleep_to(begin_time + 10)
        with self.mutex:
            if self._watch_is_over(generation):
                return
            if self.last_recognition_result is None:
                self._fail_locked("首包延迟内未收到返回")
                return

        # Catch-up ("death") line, first checked at begin+30 s.
        next_checktime = begin_time + 30
        ddl = begin_time + max((audio_duration / 6) + 10, 30)
        while time.time() < ddl:
            self.sleep_to(next_checktime)
            with self.mutex:
                if self._watch_is_over(generation):
                    return
                if self.last_recognition_result is None:
                    self._fail_locked("检测追赶线过程中获取最后一次识别结果异常")
                    return
                last_end_time = self.last_recognition_result.end_time
                expect_end_time = (next_checktime - begin_time - 30) * 5.4
                caught = last_end_time < expect_end_time
                if caught:
                    log.warning(
                        "识别时间位置 %s 被死亡追赶线 %s 已追上,将置为产品不可用",
                        last_end_time,
                        expect_end_time,
                    )
                    self.product_avaiable = False
            if caught:
                # Sleep outside the lock so callbacks and heartbeats keep flowing.
                self.sleep_to(ddl)
                break
            # Next check: 1 s after the line would reach the current position.
            next_checktime = min(last_end_time / 5.4 + begin_time + 30 + 1, ddl)

        with self.mutex:
            if self._watch_is_over(generation):
                return
            log.warning("识别速度rtf低于1/6, 将置为产品不可用")
            self.product_avaiable = False

        self.sleep_to(begin_time + max((audio_duration / 3) + 10, 30))
        with self.mutex:
            if self._watch_is_over(generation):
                return
            self._fail_locked("处理时间超过ddl %s " % (ddl - begin_time))

    def heartbeat_check(self, generation: Optional[int] = None):
        """Fail the task if no RUNNING heartbeat arrives for more than 30 s (polled every 5 s).

        :param generation: task generation to watch; defaults to the current one.
        """
        if generation is None:
            generation = self._task_generation
        with self.mutex:
            self.last_heartbeat_time = time.time()
        while True:
            with self.mutex:
                if self._watch_is_over(generation):
                    return
                elapsed = time.time() - self.last_heartbeat_time
                if elapsed > 30:
                    self._fail_locked("asr_runner 心跳超时 %s" % (elapsed,))
                    return
            time.sleep(5)

    def _watch_is_over(self, generation: int) -> bool:
        """True once the watched task finished or a newer task replaced it (hold self.mutex)."""
        return self.finished.is_set() or generation != self._task_generation

    def _fail_locked(self, error: str):
        """Log *error*, keep it if it is the task's first error, and stop the task (hold self.mutex)."""
        log.error(error)
        if self.error is None:
            self.error = error
        self.stop()

    def sleep_to(self, to: float):
        """Sleep until the wall-clock timestamp *to*; return at once if it has passed."""
        seconds = to - time.time()
        if seconds <= 0:
            return
        time.sleep(seconds)

    def stop(self):
        """Mark the running task as finished: record end_time, set ``finished``, clear app_on.

        Every caller in this class invokes it while holding ``self.mutex``.
        """
        self.end_time = time.time()
        self.finished.set()
        self.app_on = False

    def evaluate(self, query_data: QueryData):
        """Score the results collected for the last task against the label *query_data*.

        Side effects: ``self.recognition_results`` is replaced by its final
        results only (the unfiltered list is returned as ``preds``), and
        ``product_avaiable`` is set to False when the 300 ms first-word
        alignment rate is below 0.8.

        :raises StopException: if the begin, end or first-receive time is missing.
        """
        log.info("开始评估")
        missing = [
            msg
            for value, msg in (
                (self.begin_time, "评估流程异常 无开始时间"),
                (self.end_time, "评估流程异常 无结束时间"),
                (self.first_receive_time, "评估流程异常 无首次接收时间"),
            )
            if value is None
        ]
        for msg in missing:
            log.error(msg)
        if missing:
            raise StopException()

        # NOTE(review): a fixed 10 s is subtracted before computing the real-time
        # factor; the reason is not visible in this module.
        rtf = max(self.end_time - self.begin_time - 10, 0) / query_data.duration
        first_receive_delay = max(self.first_receive_time - self.begin_time, 0)
        query_count = 1
        voice_count = len(query_data.voice)

        preds = self.recognition_results
        self.recognition_results = [r for r in preds if r.final_result]

        (
            pred_punctuation_num,
            label_punctuation_num,
            pred_sentence_punctuation_num,
            label_setence_punctuation_num,
        ) = evaluate_punctuation(query_data, self.recognition_results)

        (
            cer,
            _,
            align_start,
            align_end,
            first_word_distance_sum,
            last_word_distance_sum,
        ) = evaluate_editops(query_data, self.recognition_results)

        # With no labelled sentences the alignment rate is undefined; skip the
        # check instead of dividing by zero.
        if voice_count > 0:
            align_rate = align_start.get(300, 0) / voice_count
            if align_rate < 0.8:
                log.warning(
                    "评估结果首字300ms对齐率 %s < 0.8, 将置为产品不可用",
                    align_rate,
                )
                self.product_avaiable = False

        return EvaluateResult(
            lang=query_data.lang,
            cer=cer,
            align_start=align_start,
            align_end=align_end,
            first_word_distance_sum=first_word_distance_sum,
            last_word_distance_sum=last_word_distance_sum,
            rtf=rtf,
            first_receive_delay=first_receive_delay,
            query_count=query_count,
            voice_count=voice_count,
            pred_punctuation_num=pred_punctuation_num,
            label_punctuation_num=label_punctuation_num,
            pred_sentence_punctuation_num=pred_sentence_punctuation_num,
            label_setence_punctuation_num=label_setence_punctuation_num,
            preds=preds,
            label=query_data,
        )
|