import logging
import os
import threading
import time
from typing import Dict, List, Optional

import requests
from flask import Flask, abort, request
from pydantic import BaseModel, Field, ValidationError, field_validator

from schemas.dataset import QueryData
from schemas.stream import StreamDataModel
from utils.evaluator_plus import evaluate_editops, evaluate_punctuation

from .logger import log

# IP of this pod; used to build the callback / heartbeat URLs handed to the
# ASR service. Read at import time, so a missing variable raises KeyError here.
MY_POD_IP = os.environ["MY_POD_IP"]


class StopException(Exception):
    """Raised when the prediction / evaluation flow cannot continue."""


# Metrics produced by ClientCallback.evaluate for one audio query.
# Field descriptions are part of the model schema and are kept verbatim.
# NOTE: "rtf" is computed in evaluate() as elapsed processing time divided by
# the audio duration (a real-time factor), despite its description text.
class EvaluateResult(BaseModel):
    lang: str
    cer: float
    # first-word alignment time difference (ms) -> number of aligned sentences
    align_start: Dict[int, int] = Field(
        description="句首字对齐时间差值(ms) -> 对齐数"
    )
    # last-word alignment time difference (ms) -> number of aligned sentences
    align_end: Dict[int, int] = Field(
        description="句尾字对齐时间差值(ms) -> 对齐数"
    )
    first_word_distance_sum: float = Field(description="句首字距离总和(s)")
    last_word_distance_sum: float = Field(description="句尾字距离总和(s)")
    rtf: float = Field(description="翻译速度")
    first_receive_delay: float = Field(description="首包接收延迟(s)")
    query_count: int = Field(description="音频数")
    voice_count: int = Field(description="句子数")
    pred_punctuation_num: int = Field(description="预测标点数")
    label_punctuation_num: int = Field(description="标注标点数")
    pred_sentence_punctuation_num: int = Field(description="预测句子标点数")
    label_setence_punctuation_num: int = Field(description="标注句子标点数")
    preds: List[StreamDataModel] = Field(description="预测结果")
    label: QueryData = Field(description="标注结果")


# Payload POSTed by the ASR service to the progress-callback URL.
class ResultModel(BaseModel):
    taskId: str
    status: str
    message: str = Field("")
    recognition_results: Optional[StreamDataModel] = Field(None)

    @field_validator("recognition_results", mode="after")
    def convert_to_seconds(cls, v: Optional[StreamDataModel], info):
        # Divide segment and per-word timestamps by 1000 (ms -> s, as the
        # validator name says). The model is mutated in place and returned.
        if v is None:
            return v
        v.end_time = v.end_time / 1000
        v.start_time = v.start_time / 1000
        for word in v.words:
            word.start_time /= 1000
            word.end_time /= 1000
        return v


class ClientCallback:
    """Run one ASR task at a time against a service under test.

    A Flask server (daemon thread) receives recognition callbacks and
    heartbeats from the ASR service. ``predict`` submits an audio file and
    starts watchdog threads for first-packet latency, processing speed,
    hard deadline and heartbeat. ``finished`` is set when the task ends.
    ``evaluate`` then scores the collected results against the label.
    """

    def __init__(self, sut_url: str, port: int):
        """
        Args:
            sut_url: Base URL of the ASR service (e.g. http://asr-service:8080).
            port: Local port this client listens on for callbacks/heartbeats.
        """
        self.sut_url = sut_url
        self.port = port

        # All shared state must exist before the HTTP server starts; an
        # early request would otherwise hit missing attributes.
        self.mutex = threading.Lock()
        self.finished = threading.Event()
        self.product_avaiable = True
        # Bumped by reset(); watchdog threads capture it so that threads left
        # over from a previous task notice they are stale and exit.
        self._generation = 0
        self.reset()

        self.app = Flask(__name__)
        # Route 1: recognition-result callbacks; taskId is a path parameter
        # matching the progressCallbackUrl built in predict().
        self.app.add_url_rule(
            "/api/asr/batch-callback/<taskId>",
            view_func=self.asr_callback,
            methods=["POST"],
        )
        # Route 2: heartbeat reports from the ASR runner.
        self.app.add_url_rule(
            "/api/asr-runner/report",
            view_func=self.heartbeat,
            methods=["POST"],
        )
        logging.getLogger("werkzeug").disabled = True
        threading.Thread(
            target=self.app.run, args=("0.0.0.0", port), daemon=True
        ).start()

    def reset(self):
        """Clear all per-task state and invalidate running watchdog threads."""
        self._generation += 1
        self.task_id: Optional[str] = None
        self.begin_time = None
        self.end_time = None
        self.first_receive_time = None
        self.last_heartbeat_time = None
        self.app_on = False
        self.para_seq = 0
        self.finished.clear()
        self.error: Optional[str] = None
        self.last_recognition_result: Optional[StreamDataModel] = None
        self.recognition_results: List[StreamDataModel] = []

    def _is_current(self, generation: int) -> bool:
        """Return True if *generation* is the running, unfinished task.

        Must be called with ``self.mutex`` held.
        """
        return generation == self._generation and not self.finished.is_set()

    def _fail(self, error: str) -> None:
        """Log *error*, keep it if it is the first error, and stop the task.

        Must be called with ``self.mutex`` held.
        """
        log.error(error)
        if self.error is None:
            self.error = error
        self.stop()

    def _ensure_current_task(self, task_id: str) -> None:
        """Abort with 400 if *task_id* is not the task started by predict().

        Must be called with ``self.mutex`` held.
        """
        if task_id != self.task_id:
            log.error("asr_callback: taskId不匹配: %s != %s", task_id, self.task_id)
            abort(400)

    def asr_callback(self, taskId: str):
        """Handle one progress callback from the ASR service.

        Enforces ordering: para_seq must not go backwards or skip an
        unconfirmed segment, and start_time must not precede the previous
        result's end_time. Any violation records ``self.error`` and stops
        the task.

        Args:
            taskId: Task id from the URL path; must match the current task.

        Returns:
            "ok" on success; otherwise aborts with HTTP 400/404.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)  # None when the body is not JSON
        if body is None:
            abort(404)
        try:
            result = ResultModel.model_validate(body)
        except ValidationError as e:
            log.error("asr_callback: 结果格式错误: %s", e)
            abort(404)

        if result.status == "FINISHED":
            with self.mutex:
                self._ensure_current_task(taskId)
                # The task may already have been stopped by a watchdog; don't
                # overwrite its end_time.
                if self.app_on:
                    self.stop()
            return "ok"

        if result.status != "RUNNING":
            log.error(
                "asr_callback: 结果状态错误: %s, message: %s",
                result.status,
                result.message,
            )
            abort(404)

        recognition_result = result.recognition_results
        if recognition_result is None:
            log.error("asr_callback: 结果中没有recognition_results字段")
            abort(404)

        with self.mutex:
            if not self.app_on:
                log.error("asr_callback: 应用已结束")
                abort(400)
            self._ensure_current_task(taskId)
            if recognition_result.para_seq < self.para_seq:
                self._fail(
                    "asr_callback: 结果中para_seq小于上一次的: %d < %d"
                    % (recognition_result.para_seq, self.para_seq)
                )
                abort(404)
            if recognition_result.para_seq > self.para_seq + 1:
                # A jump of more than one means segment para_seq + 1 was never
                # confirmed by a final_result=True message.
                self._fail(
                    "asr_callback: 结果中para_seq大于上一次的+1 "
                    "说明存在para_seq = %d没有final_result为True确认"
                    % (self.para_seq + 1,)
                )
                abort(404)
            if (
                self.last_recognition_result is not None
                and recognition_result.start_time
                < self.last_recognition_result.end_time
            ):
                self._fail(
                    "asr_callback: 结果中start_time小于上一次的end_time: %s < %s"
                    % (
                        recognition_result.start_time,
                        self.last_recognition_result.end_time,
                    )
                )
                abort(404)

            self.recognition_results.append(recognition_result)
            if recognition_result.final_result is True:
                self.para_seq = recognition_result.para_seq
            if self.last_recognition_result is None:
                self.first_receive_time = time.time()
            self.last_recognition_result = recognition_result
        return "ok"

    def heartbeat(self):
        """Handle a heartbeat report from the ASR runner.

        Only a "RUNNING" status refreshes ``last_heartbeat_time``. Any other
        status is logged and acknowledged without refreshing it.
        """
        if self.app_on is False:
            abort(400)
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            abort(404)
        status = body.get("status")
        if status != "RUNNING":
            message = body.get("message", "")
            if message:
                message = f", message: {message}"
            log.error("heartbeat: 状态错误: %s%s", status, message)
            return "ok"
        with self.mutex:
            self.last_heartbeat_time = time.time()
        return "ok"

    def predict(
        self,
        language: Optional[str],
        audio_file: str,
        audio_duration: float,
        task_id: str,
    ):
        """Submit *audio_file* to the ASR service and start the watchdogs.

        Returns once the service has accepted the task. Results then arrive
        asynchronously via asr_callback, and ``self.finished`` is set when
        the task ends.

        Raises:
            StopException: a previous task is still running, or /predict did
                not answer HTTP 200 with status "OK".
        """
        with self.mutex:
            if self.app_on:
                log.error("上一音频尚未完成处理,流程出现异常")
                raise StopException()
            self.reset()
            self.task_id = task_id
            # Enabled before the request: callbacks may arrive immediately.
            self.app_on = True
            generation = self._generation

        with open(audio_file, "rb") as audio:
            resp = requests.post(
                self.sut_url + "/predict",
                data={
                    "language": language,
                    "taskId": task_id,
                    "progressCallbackUrl": "http://%s:%d/api/asr/batch-callback/%s"
                    % (MY_POD_IP, self.port, task_id),
                    "heartbeatUrl": "http://%s:%d/api/asr-runner/report"
                    % (MY_POD_IP, self.port),
                },
                files={"file": (audio_file, audio)},
                timeout=60,
            )
        if resp.status_code != 200:
            log.error("/predict接口返回http code %s", resp.status_code)
            raise StopException()
        status = resp.json().get("status")
        if status != "OK":
            log.error("/predict接口返回非OK状态: %s", status)
            raise StopException()

        threading.Thread(
            target=self.dead_line_check,
            args=(audio_duration, generation),
            daemon=True,
        ).start()
        threading.Thread(
            target=self.heartbeat_check, args=(generation,), daemon=True
        ).start()

    def dead_line_check(self, audio_duration: float, generation: Optional[int] = None):
        """Enforce latency and deadline limits for one task.

        Timeline, in seconds after begin_time:
          * 10: fail the task if no recognition result has arrived.
          * 30 .. ddl, where ddl = max(audio_duration / 6 + 10, 30): compare
            the latest result's end_time with a line that starts at 30 s and
            advances 5.4 s of audio per second. If the line overtakes it,
            mark the product unavailable.
          * If the task is still unfinished at ddl, mark the product
            unavailable. If still unfinished at
            max(audio_duration / 3 + 10, 30), fail the task.

        Args:
            audio_duration: Audio length (assumed seconds; compared with
                result end_time after conversion to seconds).
            generation: Task generation to watch; defaults to the current one.
                The thread exits as soon as that task is finished or replaced.
        """
        if generation is None:
            generation = self._generation
        begin_time = time.time()
        with self.mutex:
            if generation != self._generation:
                return
            self.begin_time = begin_time

        # First-packet check after 10 s.
        self.sleep_to(begin_time + 10)
        with self.mutex:
            if not self._is_current(generation):
                return
            if self.last_recognition_result is None:
                self._fail("首包延迟内未收到返回")
                return

        # Catch-up line checks, first one at 30 s.
        next_checktime = begin_time + 30
        ddl = begin_time + max((audio_duration / 6) + 10, 30)
        while time.time() < ddl:
            self.sleep_to(next_checktime)
            with self.mutex:
                if not self._is_current(generation):
                    return
                if self.last_recognition_result is None:
                    self._fail("检测追赶线过程中获取最后一次识别结果异常")
                    return
                last_end_time = self.last_recognition_result.end_time
            expect_end_time = (next_checktime - begin_time - 30) * 5.4
            if last_end_time < expect_end_time:
                log.warning(
                    "识别时间位置 %s 被死亡追赶线 %s 已追上,将置为产品不可用",
                    last_end_time,
                    expect_end_time,
                )
                self.product_avaiable = False
                # Sleep outside the lock so callbacks keep being processed.
                self.sleep_to(ddl)
                break
            next_checktime = min(last_end_time / 5.4 + begin_time + 30 + 1, ddl)

        with self.mutex:
            if not self._is_current(generation):
                return
        log.warning("识别速度rtf低于1/6, 将置为产品不可用")
        self.product_avaiable = False

        # Hard deadline.
        self.sleep_to(begin_time + max((audio_duration / 3) + 10, 30))
        with self.mutex:
            if not self._is_current(generation):
                return
            self._fail("处理时间超过ddl %s " % (ddl - begin_time))

    def heartbeat_check(self, generation: Optional[int] = None):
        """Fail the task if no heartbeat is seen for more than 30 s.

        Polls every 5 s. Exits once the watched task is finished or has been
        replaced by a newer one.

        Args:
            generation: Task generation to watch; defaults to the current one.
        """
        if generation is None:
            generation = self._generation
        with self.mutex:
            if generation != self._generation:
                return
            self.last_heartbeat_time = time.time()
        while True:
            with self.mutex:
                if not self._is_current(generation):
                    return
                elapsed = time.time() - self.last_heartbeat_time
                if elapsed > 30:
                    self._fail("asr_runner 心跳超时 %s" % elapsed)
                    return
            time.sleep(5)

    def sleep_to(self, to: float):
        """Sleep until wall-clock time *to* (time.time() scale); no-op if past."""
        seconds = to - time.time()
        if seconds <= 0:
            return
        time.sleep(seconds)

    def stop(self):
        """Record end_time, set ``finished`` and stop accepting callbacks.

        Callers in this class invoke it with ``self.mutex`` held.
        """
        self.end_time = time.time()
        self.finished.set()
        self.app_on = False

    def evaluate(self, query_data: QueryData):
        """Score the finished task against *query_data*.

        Side effects: replaces ``self.recognition_results`` with only the
        final results. May set ``product_avaiable`` to False if fewer than
        80% of sentences align their first word within 300 ms.

        Returns:
            EvaluateResult with all collected results in ``preds``.

        Raises:
            StopException: begin, end or first-receive time is missing.
        """
        log.info("开始评估")
        if (
            self.begin_time is None
            or self.end_time is None
            or self.first_receive_time is None
        ):
            if self.begin_time is None:
                log.error("评估流程异常 无开始时间")
            if self.end_time is None:
                log.error("评估流程异常 无结束时间")
            if self.first_receive_time is None:
                log.error("评估流程异常 无首次接收时间")
            raise StopException()

        rtf = max(self.end_time - self.begin_time - 10, 0) / query_data.duration
        first_receive_delay = max(self.first_receive_time - self.begin_time, 0)
        query_count = 1
        voice_count = len(query_data.voice)
        preds = self.recognition_results
        # Only confirmed (final) results are scored; preds keeps everything.
        self.recognition_results = [r for r in self.recognition_results if r.final_result]
        (
            pred_punctuation_num,
            label_punctuation_num,
            pred_sentence_punctuation_num,
            label_setence_punctuation_num,
        ) = evaluate_punctuation(query_data, self.recognition_results)
        (
            cer,
            _,
            align_start,
            align_end,
            first_word_distance_sum,
            last_word_distance_sum,
        ) = evaluate_editops(query_data, self.recognition_results)

        # The 300 ms alignment rate is undefined without labelled sentences.
        if voice_count > 0:
            align_rate = align_start[300] / voice_count
            if align_rate < 0.8:
                log.warning(
                    "评估结果首字300ms对齐率 %s < 0.8, 将置为产品不可用",
                    align_rate,
                )
                self.product_avaiable = False

        return EvaluateResult(
            lang=query_data.lang,
            cer=cer,
            align_start=align_start,
            align_end=align_end,
            first_word_distance_sum=first_word_distance_sum,
            last_word_distance_sum=last_word_distance_sum,
            rtf=rtf,
            first_receive_delay=first_receive_delay,
            query_count=query_count,
            voice_count=voice_count,
            pred_punctuation_num=pred_punctuation_num,
            label_punctuation_num=label_punctuation_num,
            pred_sentence_punctuation_num=pred_sentence_punctuation_num,
            label_setence_punctuation_num=label_setence_punctuation_num,
            preds=preds,
            label=query_data,
        )