314 lines
9.5 KiB
Python
314 lines
9.5 KiB
Python
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import Optional
|
|
|
|
import flask
|
|
import requests
|
|
from werkzeug.datastructures import FileStorage
|
|
|
|
# WSGI application that serves the /predict endpoint.
app = flask.Flask(__name__)

# Flipped to True by heartbeat() the first time it runs; any later call sees
# it set and returns immediately, so at most one heartbeat loop is started.
heartbeat_active = False

# Module logger configuration: INFO and above, emitted through a dedicated
# StreamHandler (stderr by default) with the format below. propagate is
# disabled so records are not also handed to the root logger's handlers.
level = logging.INFO

formatter = logging.Formatter(
    fmt="[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
streamHandler.setLevel(level)

log = logging.getLogger(__name__)
log.setLevel(level)
log.propagate = False
log.addHandler(streamHandler)
|
|
|
|
|
|
# Makes the check-and-set of heartbeat_active atomic, so concurrent
# /predict requests cannot each start their own heartbeat loop.
_heartbeat_lock = threading.Lock()


def heartbeat(url, interval: float = 10.0, timeout: float = 5.0):
    """Post ``{"status": "RUNNING"}`` to ``url`` forever, every ``interval`` seconds.

    Only the first call in a process starts the loop. Every later call returns
    immediately, even if it passes a different ``url``. Failed posts are logged
    and the loop keeps going (best effort).

    Args:
        url: Heartbeat endpoint to post to.
        interval: Seconds to sleep between heartbeats (defaults to the
            original 10 s).
        timeout: Per-request timeout in seconds, so an unresponsive endpoint
            cannot stall the loop indefinitely.
    """
    global heartbeat_active
    with _heartbeat_lock:
        if heartbeat_active:
            return
        heartbeat_active = True

    while True:
        try:
            requests.post(url, json={"status": "RUNNING"}, timeout=timeout)
        except Exception as exc:
            # A failed heartbeat must never end the loop; just record it.
            log.warning("heartbeat to %s failed: %s", url, exc)
        time.sleep(interval)
|
|
|
|
|
|
def asr(
    audio_file: FileStorage,
    language: Optional[str],
    progressCallbackUrl: str,
    taskId: str,
):
    """TODO: read audio_file, call the speech-recognition service and report
    recognition results in real time.

    The current body is a placeholder. It posts one hard-coded sample result
    and then a FINISHED status.

    Every callback is a JSON POST to ``progressCallbackUrl`` carrying
    ``taskId`` and ``status``. Only RUNNING callbacks carry
    ``recognition_results`` (an incremental result); FINISHED and ERROR
    callbacks must leave that field out.

    If any step fails, the error is logged and an ERROR status is posted, so
    the caller is not left waiting for a terminal status.

    Args:
        audio_file: The uploaded audio.
        language: Requested language code, or None when the caller left
            language detection to this service.
        progressCallbackUrl: URL that receives the progress callbacks.
        taskId: Task identifier echoed in every callback.
    """
    # ignore BEGIN
    # Used only for local leaderboard testing.
    if os.getenv("LOCAL_TEST"):
        return local_test(progressCallbackUrl, taskId)
    # ignore END

    def notify(payload):
        # Bounded timeout so an unresponsive callback endpoint cannot block
        # this worker thread forever.
        requests.post(progressCallbackUrl, json=payload, timeout=10)

    # Placeholder behaviour: honor the requested language and fall back to
    # "de" only when none was supplied.
    language = language or "de"

    try:
        # One incremental recognition result.
        # NOTE(review): sample data - the paragraph end_time (6421) is earlier
        # than the last word's end_time (6435).
        notify(
            {
                "taskId": taskId,
                "status": "RUNNING",
                # Incremental result; omit this field when status is
                # FINISHED or ERROR.
                "recognition_results": {
                    "text": "最先启动的还是",
                    "final_result": True,
                    "para_seq": 0,
                    "language": language,
                    "start_time": 6300,
                    "end_time": 6421,
                    "words": [
                        {
                            "text": "最",
                            "start_time": 6300,
                            "end_time": 6321,
                        },
                        {
                            "text": "先",
                            "start_time": 6321,
                            "end_time": 6345,
                        },
                        {
                            "text": "启",
                            "start_time": 6345,
                            "end_time": 6350,
                        },
                        {
                            "text": "动",
                            "start_time": 6350,
                            "end_time": 6370,
                        },
                        {
                            "text": "的",
                            "start_time": 6370,
                            "end_time": 6386,
                        },
                        {
                            "text": "还",
                            "start_time": 6386,
                            "end_time": 6421,
                        },
                        {
                            "text": "是",
                            "start_time": 6421,
                            "end_time": 6435,
                        },
                    ],
                },
            }
        )
        # ... the remaining recognition results would be sent here.

        # Recognition finished.
        notify({"taskId": taskId, "status": "FINISHED"})
    except Exception:
        # Without this the thread would die silently and no terminal status
        # would ever reach the caller.
        log.exception("ASR task %s failed", taskId)
        try:
            notify({"taskId": taskId, "status": "ERROR"})
        except Exception:
            log.exception("could not report ERROR for task %s", taskId)
|
|
|
|
|
|
@app.post("/predict")
def predict():
    """Accept an ASR task and process it in the background.

    Reads the multipart form fields ``taskId``, ``progressCallbackUrl``,
    ``heartbeatUrl``, the optional ``language``, and the audio upload
    ``file``. A missing required field makes its ``[...]`` lookup raise.

    Starts the per-process heartbeat loop and a recognition thread running
    ``asr``, then returns ``{"status": "OK"}`` immediately. Results are
    delivered through ``progressCallbackUrl``.
    """
    import io

    body = flask.request.form
    # None means no language was given and it must be detected by the service.
    language = body.get("language")
    taskId = body["taskId"]
    progressCallbackUrl = body["progressCallbackUrl"]
    heartbeatUrl = body["heartbeatUrl"]

    upload = flask.request.files["file"]
    # Flask closes the request's uploaded files when the request context ends.
    # That can happen before the background thread reads the audio, so buffer
    # the bytes now and hand the worker an independent FileStorage. The whole
    # upload is held in memory.
    # asr can then use audio_file.stream (read the stream) or
    # audio_file.save(path) (save to disk).
    audio_file = FileStorage(
        stream=io.BytesIO(upload.stream.read()),
        filename=upload.filename,
        name=upload.name,
        headers=upload.headers,
    )

    threading.Thread(
        target=heartbeat, args=(heartbeatUrl,), daemon=True
    ).start()
    threading.Thread(
        target=asr,
        args=(audio_file, language, progressCallbackUrl, taskId),
        daemon=True,
    ).start()
    return flask.jsonify({"status": "OK"})
|
|
|
|
|
|
# ignore BEGIN
|
|
def local_test(progressCallbackUrl: str, taskId: str):
    """Ignore this function: it exists only for local leaderboard debugging.

    Replays reference transcripts as if they were recognition results. It
    loads the YAML file named by ``LOCAL_TEST_DATA_PATH`` (default
    ``../dataset/out/data.yaml``) and takes the ``voice`` list of the first
    ``query_data`` entry. It then posts each entry's ``answer`` to
    ``progressCallbackUrl``, randomly perturbing the stream:

    * waits 3-5 s before the first result;
    * with probability 0.3 merges runs of consecutive sentences (only when
      there are at least two);
    * drops each interior text segment with probability 0.15;
    * with probability 0.4 sends a partial (``final_result=False``) result
      before the final one.

    It ends with a FINISHED callback. Entries with no words left are skipped.
    ``start``/``end`` values are multiplied by 1000 when sent, presumably
    converting seconds to milliseconds (TODO confirm against the dataset).
    """
    import random
    import re

    import yaml

    def callback(content):
        # Post one result, or FINISHED when content is None. Delivery is best
        # effort: a failed POST is logged and the replay continues.
        if content is None:
            payload = {"taskId": taskId, "status": "FINISHED"}
        else:
            payload = {
                "taskId": taskId,
                "status": "RUNNING",
                "recognition_results": content,
            }
        try:
            requests.post(progressCallbackUrl, json=payload)
        except Exception as exc:
            log.warning("callback to %s failed: %s", progressCallbackUrl, exc)

    # full_load on a local, developer-supplied file. Switch to yaml.safe_load
    # if this path could ever point at untrusted data.
    with open(
        os.getenv("LOCAL_TEST_DATA_PATH", "../dataset/out/data.yaml"),
        encoding="utf-8",
    ) as f:
        data = yaml.full_load(f)

    voices = data["query_data"][0]["voice"]

    # Delay before the first result. send_interval is always 0 (the draw is
    # multiplied by 0); the draw is kept so the random sequence is unchanged.
    first_send_time = random.randint(3, 5)
    send_interval = random.random() * 0
    log.info("首次发送%ss 发送间隔%ss", first_send_time, send_interval)
    time.sleep(first_send_time)

    # Merge runs of consecutive sentences into single results. Needs at least
    # two sentences: with one, the cut points would be [0, 0] and that single
    # sentence would be emitted twice.
    if random.random() < 0.3 and len(voices) > 1:
        log.info("将部分句子合并成单句 每次合并的句子不超过3句")
        # Build increasing cut points [0, c1, ..., n-1]. Group i spans
        # [cut_i, cut_{i+1}); the last sentence always stays on its own.
        # NOTE(review): a group can span up to 4 sentences (cut + 1 + 3),
        # while the log message says at most 3. Confirm which is intended.
        rand_idx = 0
        rand_sep = [0, len(voices) - 1]
        while rand_sep[rand_idx] + 1 <= rand_sep[rand_idx + 1] - 1:
            rand_cursep = random.randint(
                rand_sep[rand_idx] + 1,
                min(rand_sep[rand_idx + 1] - 1, rand_sep[rand_idx] + 1 + 3),
            )
            rand_sep.insert(rand_idx + 1, rand_cursep)
            rand_idx += 1
        merged_voices = []
        for i, cur_sep in enumerate(rand_sep[:-1]):
            voice = voices[cur_sep]
            for j in range(cur_sep + 1, rand_sep[i + 1]):
                voice["answer"] += voices[j]["answer"]
                voice["end"] = voices[j]["end"]
            merged_voices.append(voice)
        merged_voices.append(voices[rand_sep[-1]])
        voices = merged_voices

    def split_and_keep(text, delimiters):
        # Split text into runs of non-delimiter characters and individual
        # delimiter characters, keeping both. Delimiters must be single
        # characters; they are combined into one character class (joining
        # them with "|" would also turn "|" itself into a delimiter).
        chars = "".join(re.escape(delimiter) for delimiter in delimiters)
        return re.findall(f"(?:[^{chars}]+|[{chars}])", text)

    puncs = [",", ".", "?", "!", ";", ":"]

    para_seq = 0
    for voice in voices:
        answer: str = voice["answer"]
        start_time: float = voice["start"]
        end_time: float = voice["end"]
        words = split_and_keep(answer, puncs)
        temp_words = []
        for i, word in enumerate(words):
            # Randomly drop interior segments; the first and last are kept.
            if i > 0 and i < len(words) - 1 and random.random() < 0.15:
                log.info("随机删除word")
                continue
            temp_words.extend(word.split(" "))
        words = temp_words
        answer = " ".join(words)
        words = [word.strip() for word in words]
        words = [word for word in words if len(word) > 0]
        if not words:
            # Empty or whitespace-only answer: there is nothing to time or
            # send, and the even time split below would divide by zero.
            log.warning("skipping voice with empty answer: %r", voice["answer"])
            continue

        # Spread the sentence's duration evenly across its words.
        words_withtime = []
        word_unittime = (end_time - start_time) / len(words)
        for i, word in enumerate(words):
            word_start = start_time + word_unittime * i
            word_end = word_start + word_unittime
            words_withtime.append(
                {
                    "text": word,
                    "start_time": word_start * 1000,
                    "end_time": word_end * 1000,
                }
            )

        # Leading punctuation: the first real word is stretched back to the
        # sentence start, and the marks become zero-length at that instant.
        punc_at = 0
        while punc_at < len(words) and words[punc_at] in puncs:
            punc_at += 1
        if punc_at < len(words):
            words_withtime[punc_at]["start_time"] = words_withtime[0]["start_time"]
        for i in range(0, punc_at):
            words_withtime[i]["start_time"] = words_withtime[0]["start_time"]
            words_withtime[i]["end_time"] = words_withtime[0]["start_time"]
        # Trailing punctuation: the last real word is stretched to the
        # sentence end, and the marks become zero-length just after it (+0.1).
        punc_at = len(words) - 1
        while punc_at >= 0 and words[punc_at] in puncs:
            punc_at -= 1
        if punc_at >= 0:
            words_withtime[punc_at]["end_time"] = words_withtime[-1]["end_time"]
        for i in range(punc_at + 1, len(words)):
            words_withtime[i]["start_time"] = words_withtime[-1]["end_time"] + 0.1
            words_withtime[i]["end_time"] = words_withtime[-1]["end_time"] + 0.1

        # Sometimes send a partial (non-final) prefix of the sentence first.
        if random.random() < 0.4 and len(words_withtime) > 1:
            log.info("发送一次final_result=False")
            rand_idx = random.randint(1, len(words_withtime) - 1)
            recognition_result = {
                "text": " ".join(w["text"] for w in words_withtime[:rand_idx]),
                "final_result": False,
                "para_seq": para_seq,
                "language": "de",
                "start_time": start_time * 1000,
                "end_time": end_time * 1000,
                "words": words_withtime[:rand_idx],
            }
            callback(recognition_result)

        recognition_result = {
            "text": answer,
            "final_result": True,
            "para_seq": para_seq,
            "language": "de",
            "start_time": start_time * 1000,
            "end_time": end_time * 1000,
            "words": words_withtime,
        }
        callback(recognition_result)
        para_seq += 1
        log.info("send %s", para_seq)

        time.sleep(send_interval)

    callback(None)
|
|
|
|
|
|
# ignore END
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run Flask's built-in server, listening on every
    # interface (0.0.0.0) at port 80.
    app.run(port=80, host="0.0.0.0")
|