update
This commit is contained in:
313
starting_kit/main.py
Normal file
313
starting_kit/main.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import flask
|
||||
import requests
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
# Flask application object; the /predict route below is registered on it.
app = flask.Flask(__name__)

# Flipped to True by the first heartbeat() call so later calls return at once.
heartbeat_active = False

# Logging: a module logger that does not propagate to the root logger and
# writes INFO-and-above records through a single stream handler.
level = logging.INFO
formatter = logging.Formatter(
    fmt="[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
streamHandler.setLevel(level)

log = logging.getLogger(__name__)
log.setLevel(level)
log.propagate = False
log.addHandler(streamHandler)
|
||||
|
||||
|
||||
def heartbeat(url):
    """Post a ``RUNNING`` status to *url* every 10 seconds, forever.

    Only the first call in the process does any work: it sets the module-level
    ``heartbeat_active`` flag and enters the loop. Later calls see the flag and
    return immediately. The loop never exits, so the caller is expected to run
    this function in a daemon thread.

    Args:
        url: Endpoint that receives ``{"status": "RUNNING"}`` as a JSON body.
    """
    global heartbeat_active
    if heartbeat_active:
        return
    heartbeat_active = True

    while True:
        try:
            # Bound every request: without a timeout a stalled endpoint would
            # block this loop indefinitely and no further beats would be sent.
            requests.post(url, json={"status": "RUNNING"}, timeout=10)
        except Exception:
            # Best effort: a failed beat must not stop the loop, but it should
            # leave a trace instead of disappearing silently.
            log.warning("heartbeat to %s failed", url, exc_info=True)
        time.sleep(10)
|
||||
|
||||
|
||||
def asr(
    audio_file: FileStorage,
    language: Optional[str],
    progressCallbackUrl: str,
    taskId: str,
):
    """TODO: read audio_file, call the speech recognition service, and report results as they arrive.

    As shipped this is a placeholder: it posts one hard-coded ``RUNNING``
    update with a sample segment, followed by a ``FINISHED`` update, to
    ``progressCallbackUrl``. ``audio_file`` is not read and ``language`` is
    overwritten with ``"de"``.

    Args:
        audio_file: Uploaded audio (unused by the placeholder).
        language: Language code from the request, or ``None``.
        progressCallbackUrl: URL that receives the JSON progress updates.
        taskId: Task identifier echoed back in every update.
    """

    # ignore BEGIN
    # Used by the leaderboard for local testing
    if os.getenv("LOCAL_TEST"):
        return local_test(progressCallbackUrl, taskId)
    # ignore END

    language = "de"

    # (text, start_time, end_time) for each word of the sample segment
    sample_words = [
        ("最", 6300, 6321),
        ("先", 6321, 6345),
        ("启", 6345, 6350),
        ("动", 6350, 6370),
        ("的", 6370, 6386),
        ("还", 6386, 6421),
        ("是", 6421, 6435),
    ]
    word_entries = [
        {"text": text, "start_time": begin, "end_time": finish}
        for text, begin, finish in sample_words
    ]

    # Incremental result; do not send this field when status is FINISHED or ERROR
    segment = {
        "text": "最先启动的还是",
        "final_result": True,
        "para_seq": 0,
        "language": language,
        "start_time": 6300,
        "end_time": 6421,
        "words": word_entries,
    }

    # One recognition update
    running_update = {
        "taskId": taskId,
        "status": "RUNNING",
        "recognition_results": segment,
    }
    requests.post(progressCallbackUrl, json=running_update)
    # ... all recognition results have been sent

    # Recognition finished
    finished_update = {"taskId": taskId, "status": "FINISHED"}
    requests.post(progressCallbackUrl, json=finished_update)
|
||||
|
||||
|
||||
@app.post("/predict")
def predict():
    """Accept a recognition task and process it in the background.

    Reads the multipart form fields ``taskId``, ``progressCallbackUrl`` and
    ``heartbeatUrl`` (required), ``language`` (optional) and the uploaded file
    ``file``. Starts the heartbeat thread and an ``asr`` worker thread, then
    answers immediately.

    Returns:
        A JSON response ``{"status": "OK"}``.

    Raises:
        A required form field or the ``file`` part being absent makes the
        lookup raise; Werkzeug reports that as a 400 Bad Request.
    """
    import io

    body = flask.request.form
    # ``language`` may be None; the recognizer then has to detect the language
    # by itself.
    language = body.get("language")
    taskId = body["taskId"]
    progressCallbackUrl = body["progressCallbackUrl"]
    heartbeatUrl = body["heartbeatUrl"]

    threading.Thread(
        target=heartbeat, args=(heartbeatUrl,), daemon=True
    ).start()

    uploaded = flask.request.files["file"]
    # The worker thread outlives this request, and the request's uploaded file
    # handles are closed when the request ends. Copy the upload into memory so
    # the worker owns a stream that stays open.
    # audio_file.stream  # read the file stream
    # audio_file.save("audio.mp3")  # save the file
    audio_file = FileStorage(
        stream=io.BytesIO(uploaded.read()),
        filename=uploaded.filename,
        name=uploaded.name,
        content_type=uploaded.content_type,
    )

    threading.Thread(
        target=asr,
        args=(audio_file, language, progressCallbackUrl, taskId),
        daemon=True,
    ).start()
    return flask.jsonify({"status": "OK"})
|
||||
|
||||
|
||||
# ignore BEGIN
|
||||
def local_test(progressCallbackUrl: str, taskId: str):
    """Ignore this function; the leaderboard uses it for local debugging.

    Replays reference sentences from a YAML file as if they were recognition
    output. The file path comes from ``LOCAL_TEST_DATA_PATH`` (default
    ``../dataset/out/data.yaml``); the sentences are read from
    ``data["query_data"][0]["voice"]``, each with ``answer``, ``start`` and
    ``end`` keys. Random noise is applied: sentences may be merged, words may
    be dropped, and a partial (``final_result=False``) update may precede the
    final one. A ``FINISHED`` update is sent at the end.

    Args:
        progressCallbackUrl: URL that receives the JSON progress updates.
        taskId: Task identifier echoed back in every update.
    """
    import random
    import re

    import yaml

    puncs = [",", ".", "?", "!", ";", ":"]

    def callback(content):
        """POST one update; ``None`` means the task is finished."""
        if content is None:
            payload = {"taskId": taskId, "status": "FINISHED"}
        else:
            payload = {
                "taskId": taskId,
                "status": "RUNNING",
                "recognition_results": content,
            }
        try:
            # Bounded so a stalled callback endpoint cannot hang the replay.
            requests.post(progressCallbackUrl, json=payload, timeout=10)
        except Exception:
            # Best effort, but leave a trace instead of failing silently.
            log.warning(
                "callback to %s failed", progressCallbackUrl, exc_info=True
            )

    def split_and_keep(text, delimiters):
        """Split *text* into runs of non-delimiter text and single delimiters."""
        # The delimiters are single characters, so they form one character
        # class. They must not be joined with "|": inside [...] the pipe is a
        # literal and would itself become a delimiter.
        char_class = "".join(re.escape(delimiter) for delimiter in delimiters)
        return re.findall(f"[^{char_class}]+|[{char_class}]", text)

    def merge_sentences(sentences):
        """Merge random runs of consecutive sentences; the last one stays alone.

        Requires at least two sentences. Merged entries are modified in place.
        """
        # NOTE(review): the log message promises at most 3 sentences per merge,
        # but this upper bound allows runs of 4 - confirm which is intended.
        cuts = [0, len(sentences) - 1]
        idx = 0
        while cuts[idx] + 1 <= cuts[idx + 1] - 1:
            cut = random.randint(
                cuts[idx] + 1,
                min(cuts[idx + 1] - 1, cuts[idx] + 1 + 3),
            )
            cuts.insert(idx + 1, cut)
            idx += 1

        merged = []
        for begin, stop in zip(cuts, cuts[1:]):
            head = sentences[begin]
            for follower in sentences[begin + 1:stop]:
                head["answer"] += follower["answer"]
                head["end"] = follower["end"]
            merged.append(head)
        merged.append(sentences[cuts[-1]])
        return merged

    def spread_times(words, start_time, end_time):
        """Give every word an equal share of the span, with times scaled by 1000.

        Leading and trailing punctuation is then collapsed to zero duration and
        the neighbouring word is stretched over the freed time.
        """
        unit = (end_time - start_time) / len(words)
        timed = []
        for i, word in enumerate(words):
            word_start = start_time + unit * i
            word_end = word_start + unit
            timed.append(
                {
                    "text": word,
                    "start_time": word_start * 1000,
                    "end_time": word_end * 1000,
                }
            )

        # Leading punctuation: zero duration at the sentence start.
        lead = 0
        while lead < len(words) and words[lead] in puncs:
            lead += 1
        if lead < len(words):
            timed[lead]["start_time"] = timed[0]["start_time"]
        for i in range(lead):
            timed[i]["start_time"] = timed[0]["start_time"]
            timed[i]["end_time"] = timed[0]["start_time"]

        # Trailing punctuation: zero duration just after the sentence end.
        tail = len(words) - 1
        while tail >= 0 and words[tail] in puncs:
            tail -= 1
        if tail >= 0:
            timed[tail]["end_time"] = timed[-1]["end_time"]
        after_end = timed[-1]["end_time"] + 0.1
        for i in range(tail + 1, len(words)):
            timed[i]["start_time"] = after_end
            timed[i]["end_time"] = after_end
        return timed

    # Explicit encoding: the transcripts are not ASCII, and the platform
    # default encoding is not UTF-8 everywhere.
    with open(
        os.getenv("LOCAL_TEST_DATA_PATH", "../dataset/out/data.yaml"),
        encoding="utf-8",
    ) as f:
        # NOTE(review): full_load can construct more than plain data; this is
        # acceptable only because the file is trusted local test data. Prefer
        # yaml.safe_load if the data uses no custom tags.
        data = yaml.full_load(f)

    voices = data["query_data"][0]["voice"]

    # First send
    first_send_time = random.randint(3, 5)
    # Multiplied by 0, so the interval is currently always 0.
    send_interval = random.random() * 0
    log.info("首次发送%ss 发送间隔%ss", first_send_time, send_interval)
    time.sleep(first_send_time)

    # Join sentences together. Needs at least two sentences: with one, the
    # merge would emit it twice; with none, it would index an empty list.
    if len(voices) > 1 and random.random() < 0.3:
        log.info("将部分句子合并成单句 每次合并的句子不超过3句")
        voices = merge_sentences(voices)

    para_seq = 0
    for voice in voices:
        answer: str = voice["answer"]
        start_time: float = voice["start"]
        end_time: float = voice["end"]

        pieces = split_and_keep(answer, puncs)
        kept = []
        for i, piece in enumerate(pieces):
            # Never drop the first or the last piece.
            if 0 < i < len(pieces) - 1 and random.random() < 0.15:
                log.info("随机删除word")
                continue
            kept.extend(piece.split(" "))

        # NOTE: the text is joined before empty tokens are filtered out, so it
        # can contain repeated spaces; this is kept as-is.
        answer = " ".join(kept)
        words = [token.strip() for token in kept]
        words = [token for token in words if token]
        if not words:
            # An empty or whitespace-only answer has nothing to time or send.
            log.warning("skipping sentence without words: %r", voice)
            continue

        # Distribute the time evenly over the words
        words_withtime = spread_times(words, start_time, end_time)

        if random.random() < 0.4 and len(words_withtime) > 1:
            log.info("发送一次final_result=False")
            rand_idx = random.randint(1, len(words_withtime) - 1)
            partial_words = words_withtime[:rand_idx]
            callback(
                {
                    "text": " ".join(item["text"] for item in partial_words),
                    "final_result": False,
                    "para_seq": para_seq,
                    "language": "de",
                    "start_time": start_time * 1000,
                    "end_time": end_time * 1000,
                    "words": partial_words,
                }
            )

        callback(
            {
                "text": answer,
                "final_result": True,
                "para_seq": para_seq,
                "language": "de",
                "start_time": start_time * 1000,
                "end_time": end_time * 1000,
                "words": words_withtime,
            }
        )
        para_seq += 1
        log.info("send %s", para_seq)

        time.sleep(send_interval)

    callback(None)
|
||||
|
||||
|
||||
# ignore END
|
||||
|
||||
# Script entry point: serve the app on all interfaces, port 80.
if __name__ == "__main__":
    app.run(port=80, host="0.0.0.0")
|
||||
Reference in New Issue
Block a user