410 lines
15 KiB
Python
410 lines
15 KiB
Python
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
import threading
|
|||
|
|
import time
|
|||
|
|
from typing import Dict, List, Optional
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
from flask import Flask, abort, request
|
|||
|
|
from pydantic import BaseModel, Field, ValidationError, field_validator
|
|||
|
|
|
|||
|
|
from schemas.dataset import QueryData
|
|||
|
|
from schemas.stream import StreamDataModel
|
|||
|
|
from utils.evaluator_plus import evaluate_editops, evaluate_punctuation
|
|||
|
|
|
|||
|
|
from .logger import log
|
|||
|
|
|
|||
|
|
# IP address of this pod, read once at import time. It is used below to build
# the callback and heartbeat URLs that are handed to the ASR service.
# Importing this module raises KeyError when MY_POD_IP is not set.
MY_POD_IP: str = os.environ["MY_POD_IP"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class StopException(Exception):
    """Raised by this module to abort the flow after an unrecoverable error was logged."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class EvaluateResult(BaseModel):
    # Metrics produced by ClientCallback.evaluate for a single audio query.
    # Documented with comments rather than a docstring so the generated
    # pydantic schema is left untouched.

    # Language of the evaluated query (copied from QueryData.lang).
    lang: str
    # Error rate returned by evaluate_editops
    # (presumably character error rate -- TODO confirm).
    cer: float
    # Sentence-start alignment: time-difference threshold (ms) -> number of
    # aligned sentences. The 300 ms bucket is read by ClientCallback.evaluate.
    align_start: Dict[int, int] = Field(
        description="句首字对齐时间差值(ms) -> 对齐数"
    )
    # Sentence-end alignment: time-difference threshold (ms) -> number of
    # aligned sentences.
    align_end: Dict[int, int] = Field(
        description="句尾字对齐时间差值(ms) -> 对齐数"
    )
    # Sum of the first-word distances, in seconds.
    first_word_distance_sum: float = Field(description="句首字距离总和(s)")
    # Sum of the last-word distances, in seconds.
    last_word_distance_sum: float = Field(description="句尾字距离总和(s)")
    # The description text reads "translation speed"; ClientCallback.evaluate
    # fills it with (processing time - 10 s, floored at 0) / audio duration.
    rtf: float = Field(description="翻译速度")
    # Delay until the first result was received, in seconds.
    first_receive_delay: float = Field(description="首包接收延迟(s)")
    # Number of audio queries (ClientCallback.evaluate always sets 1).
    query_count: int = Field(description="音频数")
    # Number of sentences in the labelled data.
    voice_count: int = Field(description="句子数")
    # Number of predicted punctuation marks.
    pred_punctuation_num: int = Field(description="预测标点数")
    # Number of labelled punctuation marks.
    label_punctuation_num: int = Field(description="标注标点数")
    # Number of predicted sentence punctuation marks.
    pred_sentence_punctuation_num: int = Field(description="预测句子标点数")
    # Number of labelled sentence punctuation marks. The "setence" spelling is
    # part of the serialized field name and must not be corrected here.
    label_setence_punctuation_num: int = Field(description="标注句子标点数")
    # All received predictions (partial and final).
    preds: List[StreamDataModel] = Field(description="预测结果")
    # The labelled reference data the predictions were scored against.
    label: QueryData = Field(description="标注结果")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ResultModel(BaseModel):
    # Body of a POST to the batch-callback endpoint, as validated by
    # ClientCallback.asr_callback. Documented with comments rather than a
    # docstring so the generated pydantic schema is left untouched.

    # Task identifier reported by the ASR service.
    taskId: str
    # asr_callback treats "FINISHED" and "RUNNING" specially; anything else is
    # logged as an error.
    status: str
    # Optional free-text message accompanying the status.
    message: str = Field("")
    # Recognition payload; may be absent.
    recognition_results: Optional[StreamDataModel] = Field(None)

    @field_validator("recognition_results", mode="after")
    @classmethod
    def convert_to_seconds(
        cls, v: Optional[StreamDataModel], values
    ) -> Optional[StreamDataModel]:
        """Divide every timestamp of the recognition result by 1000.

        Per the method name this converts milliseconds to seconds. The
        conversion is applied in place to the already-validated model and to
        each entry of its ``words`` list.

        Args:
            v: The validated recognition result, or ``None`` when absent.
            values: Third positional argument supplied by pydantic; unused.
                The name is kept for interface stability.

        Returns:
            The same object with scaled timestamps, or ``None`` unchanged.
        """
        if v is None:
            return v
        v.start_time /= 1000
        v.end_time /= 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ClientCallback:
    """Run one ASR task at a time against the service under test and score it.

    The instance starts a Flask server in a daemon thread that receives
    recognition callbacks and heartbeats from the ASR service. ``predict``
    submits an audio file and starts two watchdog threads; ``evaluate`` turns
    the collected results into an ``EvaluateResult``.

    Shared state is guarded by ``self.mutex``; ``self.finished`` is set when
    the current task ends, successfully or not, and ``self.error`` holds the
    first failure reason. ``self.product_avaiable`` (spelling is part of the
    public attribute name) is cleared when a quality or speed threshold is
    missed and is never reset by this class.

    NOTE(review): the watchdog threads are not tied to a specific task. A
    thread started for one task that wakes up after the next ``predict`` has
    called ``reset`` will observe the new task's state -- confirm whether
    callers leave enough time between tasks.
    """

    def __init__(self, sut_url: str, port: int):
        """Create the callback server and start serving on ``0.0.0.0:port``.

        Args:
            sut_url: Base URL of the ASR service (e.g. http://asr-service:8080).
            port: Local port on which callbacks and heartbeats are received.
        """
        self.sut_url = sut_url
        self.port = port

        # Initialise all shared state before the server thread exists, so a
        # request that arrives immediately never sees missing attributes.
        self.mutex = threading.Lock()
        self.finished = threading.Event()
        self.product_avaiable = True
        self.reset()

        # Route 1: /api/asr/batch-callback/<taskId>
        #   Receives recognition results (handled by self.asr_callback);
        #   taskId is a path parameter identifying the task.
        # Route 2: /api/asr-runner/report
        #   Receives heartbeat reports (handled by self.heartbeat).
        self.app = Flask(__name__)
        self.app.add_url_rule(
            "/api/asr/batch-callback/<taskId>",
            view_func=self.asr_callback,
            methods=["POST"],
        )
        self.app.add_url_rule(
            "/api/asr-runner/report",
            view_func=self.heartbeat,
            methods=["POST"],
        )

        # Silence the per-request access log of the development server.
        logging.getLogger("werkzeug").disabled = True
        threading.Thread(
            target=self.app.run, args=("0.0.0.0", port), daemon=True
        ).start()

    def reset(self):
        """Clear all per-task state and the ``finished`` event."""
        self.begin_time = None
        self.end_time = None
        self.first_receive_time = None
        self.last_heartbeat_time = None
        self.app_on = False
        self.para_seq = 0
        self.finished.clear()
        self.error: Optional[str] = None
        self.last_recognition_result: Optional[StreamDataModel] = None
        self.recognition_results: List[StreamDataModel] = []

    def _fail(self, error: str) -> None:
        """Log *error*, record it if it is the first one, and stop the task.

        Callers are expected to hold ``self.mutex``.
        """
        log.error(error)
        if self.error is None:
            self.error = error
        self.stop()

    def asr_callback(self, taskId: str):
        """Flask view: accept one recognition callback from the ASR service.

        Responds 400 when no task is active and 404 for an unparsable body, a
        schema violation, an unexpected status, a missing result, or an
        ordering violation. Ordering violations also end the task with an
        error. A ``FINISHED`` status ends the task normally.

        Args:
            taskId: Path parameter from the route; not otherwise used here.

        Returns:
            The string ``"ok"`` on success.
        """
        if self.app_on is False:
            abort(400)
        # silent=True: a malformed body yields None instead of raising.
        body = request.get_json(silent=True)
        if body is None:
            abort(404)
        try:
            result = ResultModel.model_validate(body)
        except ValidationError as e:
            log.error("asr_callback: 结果格式错误: %s", e)
            abort(404)

        # Task completed.
        if result.status == "FINISHED":
            with self.mutex:
                self.stop()
            return "ok"
        # Any status other than RUNNING is an error report.
        if result.status != "RUNNING":
            log.error(
                "asr_callback: 结果状态错误: %s, message: %s",
                result.status,
                result.message,
            )
            abort(404)

        recognition_result = result.recognition_results
        if recognition_result is None:
            log.error("asr_callback: 结果中没有recognition_results字段")
            abort(404)

        with self.mutex:
            # Re-check under the lock: the task may have ended meanwhile.
            if not self.app_on:
                log.error("asr_callback: 应用已结束")
                abort(400)

            # para_seq must never go backwards ...
            if recognition_result.para_seq < self.para_seq:
                self._fail(
                    "asr_callback: 结果中para_seq小于上一次的: %d < %d"
                    % (recognition_result.para_seq, self.para_seq)
                )
                abort(404)
            # ... and may only advance past a paragraph that was confirmed
            # with final_result=True.
            if recognition_result.para_seq > self.para_seq + 1:
                self._fail(
                    "asr_callback: 结果中para_seq大于上一次的+1 "
                    "说明存在para_seq = %d没有final_result为True确认"
                    % (self.para_seq + 1,)
                )
                abort(404)
            # A result must not start before the last final result ended.
            if (
                self.last_recognition_result is not None
                and recognition_result.start_time
                < self.last_recognition_result.end_time
            ):
                self._fail(
                    "asr_callback: 结果中start_time小于上一次的end_time: %s < %s"
                    % (
                        recognition_result.start_time,
                        self.last_recognition_result.end_time,
                    )
                )
                abort(404)

            self.recognition_results.append(recognition_result)
            if recognition_result.final_result is True:
                self.para_seq = recognition_result.para_seq
                # The first final result marks the first-receive time.
                if self.last_recognition_result is None:
                    self.first_receive_time = time.time()
                self.last_recognition_result = recognition_result

        return "ok"

    def heartbeat(self):
        """Flask view: accept one heartbeat report from the ASR service.

        Responds 400 when no task is active and 404 when the body is not a
        JSON object. A status other than ``RUNNING`` is logged and does not
        refresh the heartbeat timestamp.

        Returns:
            The string ``"ok"``.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)
        # Also rejects valid JSON that is not an object (e.g. a list), which
        # would otherwise fail on .get() below.
        if not isinstance(body, dict):
            abort(404)
        status = body.get("status")
        if status != "RUNNING":
            message = body.get("message", "")
            if message:
                message = ", message: %s" % (message,)
            log.error("heartbeat: 状态错误: %s%s", status, message)
            return "ok"

        with self.mutex:
            self.last_heartbeat_time = time.time()
        return "ok"

    def predict(
        self,
        language: Optional[str],
        audio_file: str,
        audio_duration: float,
        task_id: str,
    ):
        """Submit an audio file to the ASR service and start the watchdogs.

        Args:
            language: Language sent with the request; may be ``None``.
            audio_file: Path of the audio file to upload.
            audio_duration: Duration of the audio, used for the deadlines.
            task_id: Identifier sent to the service and embedded in the
                callback URL.

        Raises:
            StopException: A previous task is still active, the service did
                not answer with HTTP 200, or its JSON status is not ``"OK"``.
            Exception: Errors from opening the file, from ``requests`` or from
                decoding the response propagate unchanged.
        """
        # Claim the instance for this task atomically.
        with self.mutex:
            if self.app_on:
                log.error("上一音频尚未完成处理,流程出现异常")
                raise StopException()
            self.reset()
            self.app_on = True

        # Request URL: self.sut_url + "/predict". The context manager closes
        # the uploaded file once the request has completed or failed.
        with open(audio_file, "rb") as audio:
            resp = requests.post(
                self.sut_url + "/predict",
                data={
                    "language": language,
                    "taskId": task_id,
                    "progressCallbackUrl": "http://%s:%d/api/asr/batch-callback/%s"
                    % (MY_POD_IP, self.port, task_id),
                    "heartbeatUrl": "http://%s:%d/api/asr-runner/report"
                    % (MY_POD_IP, self.port),
                },
                files={"file": (audio_file, audio)},
                timeout=60,
            )

        if resp.status_code != 200:
            log.error("/predict接口返回http code %s", resp.status_code)
            raise StopException()

        status = resp.json().get("status")
        if status != "OK":
            log.error("/predict接口返回非OK状态: %s", status)
            raise StopException()

        # Watchdog threads: processing deadline and heartbeat liveness.
        threading.Thread(
            target=self.dead_line_check, args=(audio_duration,), daemon=True
        ).start()
        threading.Thread(target=self.heartbeat_check, daemon=True).start()

    def dead_line_check(self, audio_duration: float):
        """Watchdog thread body enforcing the processing-speed deadlines.

        Records ``begin_time``, then:

        1. After 10 s, fails the task if no final result has arrived.
        2. From 30 s on, repeatedly checks that the end time of the last
           final result stays ahead of a line advancing 5.4 s of audio per
           wall-clock second, until ``max(audio_duration / 6 + 10, 30)`` s
           have passed. Falling behind marks the product unavailable.
        3. If the task is still running at that point, marks the product
           unavailable and waits until ``max(audio_duration / 3 + 10, 30)`` s,
           after which the task is failed.

        The mutex is only held while inspecting state, never while sleeping.
        """
        begin_time = time.time()
        self.begin_time = begin_time

        # Initial 10 s allowance for the first result.
        self.sleep_to(begin_time + 10)
        with self.mutex:
            if self.last_recognition_result is None:
                self._fail("首包延迟内未收到返回")
                return

        # First check at 30 s.
        next_checktime = begin_time + 30
        ddl = begin_time + max((audio_duration / 6) + 10, 30)
        while time.time() < ddl:
            self.sleep_to(next_checktime)
            with self.mutex:
                if self.finished.is_set():
                    return
                if self.last_recognition_result is None:
                    self._fail("检测追赶线过程中获取最后一次识别结果异常")
                    return
                last_end_time = self.last_recognition_result.end_time
                expect_end_time = (next_checktime - begin_time - 30) * 5.4
                fell_behind = last_end_time < expect_end_time
                if fell_behind:
                    log.warning(
                        "识别时间位置 %s 被死亡追赶线 %s 已追上,将置为产品不可用",
                        last_end_time,
                        expect_end_time,
                    )
                    self.product_avaiable = False
            if fell_behind:
                self.sleep_to(ddl)
                break
            # Next check one second after the line reaches the current
            # recognised position, capped at the deadline.
            next_checktime = min(last_end_time / 5.4 + begin_time + 30 + 1, ddl)

        with self.mutex:
            if self.finished.is_set():
                return
            log.warning("识别速度rtf低于1/6, 将置为产品不可用")
            self.product_avaiable = False

        self.sleep_to(begin_time + max((audio_duration / 3) + 10, 30))
        with self.mutex:
            if self.finished.is_set():
                return
            # NOTE(review): the message reports the first (1/6) deadline, not
            # the final (1/3) one that has just expired -- confirm intent.
            self._fail("处理时间超过ddl %s " % (ddl - begin_time))

    def heartbeat_check(self):
        """Watchdog thread body: fail the task after 30 s without a heartbeat.

        Polls every 5 s and returns as soon as the task is finished.
        """
        self.last_heartbeat_time = time.time()
        while True:
            with self.mutex:
                if self.finished.is_set():
                    return
                silence = time.time() - self.last_heartbeat_time
                if silence > 30:
                    self._fail("asr_runner 心跳超时 %s" % (silence,))
                    return
            time.sleep(5)

    def sleep_to(self, to: float):
        """Sleep until the wall-clock timestamp *to*; return at once if past."""
        seconds = to - time.time()
        if seconds <= 0:
            return
        time.sleep(seconds)

    def stop(self):
        """Mark the current task as ended and signal ``finished``.

        Callers in this class hold ``self.mutex``.
        """
        self.end_time = time.time()
        self.finished.set()
        self.app_on = False

    def evaluate(self, query_data: "QueryData") -> "EvaluateResult":
        """Score the collected results of the finished task against the labels.

        As a side effect ``self.recognition_results`` is reduced to the final
        results, and ``self.product_avaiable`` is cleared when fewer than 80 %
        of the sentences have their first word aligned within 300 ms.

        Args:
            query_data: Labelled reference data for the audio.

        Returns:
            The metrics for this single query. ``preds`` contains every
            received result, not only the final ones.

        Raises:
            StopException: The begin, end or first-receive time is missing.
        """
        log.info("开始评估")
        timing_problems = [
            message
            for value, message in (
                (self.begin_time, "评估流程异常 无开始时间"),
                (self.end_time, "评估流程异常 无结束时间"),
                (self.first_receive_time, "评估流程异常 无首次接收时间"),
            )
            if value is None
        ]
        for message in timing_problems:
            log.error(message)
        if timing_problems:
            raise StopException()

        # 10 s are deducted from the processing time, floored at zero
        # (presumably the first-result allowance of dead_line_check -- confirm).
        rtf = max(self.end_time - self.begin_time - 10, 0) / query_data.duration
        first_receive_delay = max(self.first_receive_time - self.begin_time, 0)
        query_count = 1
        voice_count = len(query_data.voice)

        # Keep every result for the report; score only the final ones.
        preds = self.recognition_results
        self.recognition_results = [
            result for result in self.recognition_results if result.final_result
        ]

        (
            pred_punctuation_num,
            label_punctuation_num,
            pred_sentence_punctuation_num,
            label_setence_punctuation_num,
        ) = evaluate_punctuation(query_data, self.recognition_results)

        (
            cer,
            _,
            align_start,
            align_end,
            first_word_distance_sum,
            last_word_distance_sum,
        ) = evaluate_editops(query_data, self.recognition_results)

        if voice_count > 0:
            start_align_rate = align_start[300] / voice_count
            if start_align_rate < 0.8:
                log.warning(
                    "评估结果首字300ms对齐率 %s < 0.8, 将置为产品不可用",
                    start_align_rate,
                )
                self.product_avaiable = False
        else:
            # No labelled sentences: the rate is undefined, so skip the check
            # instead of dividing by zero.
            log.warning("评估结果无句子, 跳过首字对齐率检查")

        return EvaluateResult(
            lang=query_data.lang,
            cer=cer,
            align_start=align_start,
            align_end=align_end,
            first_word_distance_sum=first_word_distance_sum,
            last_word_distance_sum=last_word_distance_sum,
            rtf=rtf,
            first_receive_delay=first_receive_delay,
            query_count=query_count,
            voice_count=voice_count,
            pred_punctuation_num=pred_punctuation_num,
            label_punctuation_num=label_punctuation_num,
            pred_sentence_punctuation_num=pred_sentence_punctuation_num,
            label_setence_punctuation_num=label_setence_punctuation_num,
            preds=preds,
            label=query_data,
        )
|