Add fastapi service

This commit is contained in:
2026-02-04 17:34:39 +08:00
parent 417022a584
commit c3ac3f045b
15 changed files with 3401 additions and 14 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
.vscode

View File

@@ -1,20 +1,22 @@
FROM corex:4.3.0
FROM git.modelhub.org.cn:9443/enginex-iluvatar/mr-bi150-4.3.0-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.3
WORKDIR /root
COPY requirements.txt /root
RUN pip install -r requirements.txt
RUN sed -i 's|deb.debian.org|archive.debian.org|g' /etc/apt/sources.list \
&& sed -i 's|security.debian.org|archive.debian.org|g' /etc/apt/sources.list \
&& sed -i 's|buster-updates|buster|g' /etc/apt/sources.list \
&& printf 'Acquire::Check-Valid-Until "false";\n' > /etc/apt/apt.conf.d/99no-check-valid \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN apt update && apt install -y vim net-tools
RUN pip install funasr==1.2.6 openai-whisper
ADD . /root/
ADD nltk_data.tar.gz /root/
RUN tar -xvzf nltk_data.tar.gz
ENV NLTK_DATA=/root/nltk_data
RUN cp ./replaced_files/mr_v100/cif_predictor.py /usr/local/lib/python3.10/site-packages/funasr/models/paraformer/
EXPOSE 80
ENTRYPOINT ["bash"]
CMD ["./start_funasr.sh"]
COPY requirements.txt /root
RUN pip install -r /root/requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple
# Patch files
COPY fastapi_funasr.py /root/fastapi_funasr.py
COPY ./replaced_files/mr_v100/cif_predictor.py /usr/local/lib/python3.10/site-packages/funasr/models/paraformer/
COPY ./replaced_files/funasr_nano_model.py /usr/local/lib/python3.10/site-packages/funasr/models/fun_asr_nano/model.py

265
fastapi_funasr.py Normal file
View File

@@ -0,0 +1,265 @@
import os
import time
import argparse
import torchaudio
import torch
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uuid
import uvicorn
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from funasr.models.fun_asr_nano.model import FunASRNano
os.makedirs("./input", exist_ok=True)
status = "Running"
model = None
device = ""
app = FastAPI()
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
import torch_dipu
def make_all_dense(module: torch.nn.Module):
for name, param in list(module.named_parameters(recurse=True)):
if getattr(param, "is_sparse", False) and param.is_sparse:
with torch.no_grad():
dense = param.to_dense().contiguous()
parent = module
*mods, leaf = name.split(".")
for m in mods:
parent = getattr(parent, m)
setattr(parent, leaf, torch.nn.Parameter(dense, requires_grad=param.requires_grad))
# 处理 buffer如 running_mean 等)
for name, buf in list(module.named_buffers(recurse=True)):
# PyTorch 稀疏张量 layout 不是 strided
if buf.layout != torch.strided:
dense = buf.to_dense().contiguous()
parent = module
*mods, leaf = name.split(".")
for m in mods:
parent = getattr(parent, m)
parent.register_buffer(leaf, dense, persistent=True)
def split_audio(waveform, sample_rate, segment_seconds=20):
segment_samples = segment_seconds * sample_rate
segments = []
for i in range(0, waveform.shape[1], segment_samples):
segment = waveform[:, i:i + segment_samples]
if segment.shape[1] > 0:
segments.append(segment)
return segments
# def determine_model_type(model_name):
# if "sensevoice" in model_name.lower():
# return "sensevoice"
# elif "whisper" in model_name.lower():
# return "whisper"
# elif "paraformer" in model_name.lower():
# return "paraformer"
# elif "conformer" in model_name.lower():
# return "conformer"
# elif "uniasr" in model_name.lower():
# return "uni_asr"
# else:
# return "unknown"
@app.on_event("startup")
def load_model():
global status, model, device
config = app.state.config
use_gpu = config.get("use_gpu", True)
model_dir = config.get("model_dir", "/model")
model_type = config.get("model_type", "sensevoice")
warmup = config.get("warmup", False)
print(">> Startup config:")
print(" model_dir =", model_dir, flush=True)
print(" model_type =", model_type, flush=True)
print(" use_gpu =", use_gpu, flush=True)
print(" warmup =", warmup, flush=True)
device = "cpu"
if use_gpu:
if CUSTOM_DEVICE.startswith("mlu"):
device = "mlu:0"
elif CUSTOM_DEVICE.startswith("ascend"):
device = "npu:1"
else:
device = "cuda:0"
# 针对加速卡的特殊处理部分
if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
# 天垓100情况下的Whisper需要绕过不支持算子
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
print(f"device: {device}", flush=True)
dense_convert = False
if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
dense_convert = True
if device.startswith("npu") and model_type == "whisper":
# Ascend NPU 加载whisper的部分会有Sparse部分device不匹配
dense_convert = True
print(f"dense_convert: {dense_convert}", flush=True)
if dense_convert:
model = AutoModel(
model=model_dir,
vad_model=None,
disable_update=True,
device="cpu"
)
make_all_dense(model.model)
model.model.to(dtype=torch.float32, memory_format=torch.contiguous_format)
model.model.to(device)
model.kwargs["device"] = device
else:
# 不使用VAD, punctspk模型就测试原始ASR能力
model = AutoModel(
model=model_dir,
# vad_model="fsmn-vad",
# vad_kwargs={"max_single_segment_time": 30000},
vad_model=None,
device=device,
disable_update=True
)
if device.startswith("npu") or warmup:
# Ascend NPU由于底层设计的不同初始化卡的调度比其他卡更复杂要先进行warmup
print("Start warmup...", flush=True)
res = model.generate(input="warmup.wav")
print("warmup complete.", flush=True)
status = "Success"
def test_funasr(audio_file, lang):
# 推理部分
waveform, sample_rate = torchaudio.load(audio_file)
# print(waveform.shape)
duration = waveform.shape[1] / sample_rate
segments = split_audio(waveform, sample_rate, segment_seconds=20)
generated_text = ""
processing_time = 0
model_type = app.state.config.get("model_type", "sensevoice")
if model_type == "uni_asr":
# uni_asr比较特殊设计就是处理长音频的自带VAD切分的话前20s如果几乎没有人讲话全是音乐直接会报错
# 因为可能会被切掉所有音频导致实际编解码输入为0
ts1 = time.time()
res = model.generate(
input=audio_file
)
generated_text = res[0]["text"]
ts2 = time.time()
processing_time = ts2 - ts1
else:
# 按照切分的音频依次输入
for i, segment in enumerate(segments):
segment_path = f"temp_seg_{i}.wav"
torchaudio.save(segment_path, segment, sample_rate)
ts1 = time.time()
if model_type == "sensevoice":
res = model.generate(
input=segment_path,
cache={},
language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
use_itn=True,
batch_size_s=60,
merge_vad=False,
# merge_length_s=15,
)
text = rich_transcription_postprocess(res[0]["text"])
elif model_type == "whisper":
DecodingOptions = {
"task": "transcribe",
"language": lang,
"beam_size": None,
"fp16": False,
"without_timestamps": False,
"prompt": None,
}
res = model.generate(
DecodingOptions=DecodingOptions,
input=segment_path,
batch_size_s=0,
)
text = res[0]["text"]
elif model_type == "paraformer":
res = model.generate(
input=segment_path,
batch_size_s=300
)
text = res[0]["text"]
# paraformer模型会一个字一个字输出中间夹太多空格会影响1-cer的结果
text = text.replace(" ", "")
elif model_type == "conformer":
res = model.generate(
input=segment_path,
batch_size_s=300
)
text = res[0]["text"]
# elif model_type == "uni_asr":
# if i == 0:
# os.remove(segment_path)
# continue
# res = model.generate(
# input=segment_path
# )
# text = res[0]["text"]
else:
raise RuntimeError("unknown model type")
ts2 = time.time()
generated_text += text
processing_time += (ts2 - ts1)
os.remove(segment_path)
rtf = processing_time / duration
print("Text:", generated_text, flush=True)
print(f"Audio duration:\t{duration:.3f} s", flush=True)
print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
return generated_text
@app.get("/health")
def health():
if status=="Running":
return {
"status":"loading model"
}
ret = {
"status": "ok" if status == "Success" else "failed",
}
return ret
@app.post("/transduce")
def transduce(
audio: UploadFile = File(...),
lang: str = Form("zh"),
background_tasks: BackgroundTasks = None
):
try:
file_path = f"./input/{uuid.uuid4()}.wav"
with open(file_path, "wb") as f:
f.write(audio.file.read())
background_tasks.add_task(os.remove, file_path)
generated_text = test_funasr(file_path, lang)
return {"generated_text": generated_text}
except Exception:
raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
# if __name__ == "__main__":
# uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)

Binary file not shown.

View File

@@ -1 +0,0 @@
朋友们晚上好欢迎大家来参加今天晚上的活动谢谢大家。这是我第四次办年度演讲前三次呢因为疫情的原因都在小米科技园内举办。现场呢人很少。这是第四次我们仔细想了想我们还是想办一个比较大的聚会。然后呢让我们的新朋友老朋友一起聚一聚。今天的话呢我们就在北京的国家会议中心呢举办了这么一个活动。现场呢来了很多人大概有3500人。还有很多很多的朋友呢通过观看直播的方式来参与。再一次呢对大家的参加表示感谢谢谢大家。两个月前我参加了今年武汉大学的毕业典礼。今年呢是武汉大学建校130周年作为校友被母校邀请在毕业典礼上致辞这对我来说是至高无上的荣誉。站在讲台的那一刻面对全校师生关于武大的所有的记忆一下子涌现在脑海里。今天呢我就先和大家聊聊武大往事。那还是36年前1987年我呢考上了武汉大学的计算机系。在武汉大学的图书馆里看了一本书《硅谷之火》建立了我一生的梦想。看完书以后热血沸腾激动得睡不着觉。我还记得那天晚上星光很亮。我就在武大的操场上就是屏幕上这个操场走了一圈又一圈走了整整一个晚上。我心里有团火我也想办一个伟大的公司就是这样梦想之火在我心里彻底点燃了。但是一个大一的新生但是一个大一的新生一个从县城里出来的年轻人什么也不会什么也没有就想创办一家伟大的公司这不就是天方夜谭吗这么离谱的一个梦想该如何实现呢那天晚上我想了一整晚上说实话越想越糊涂完全理不清头绪后来我在想哎干脆别想了把书念好是正事所以呢我就下定决心认认真真读书那么我怎么能够把书读得不同凡响呢

27
main.py Normal file
View File

@@ -0,0 +1,27 @@
import argparse
import uvicorn
from fastapi_funasr import app
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="/model", help="model directory")
parser.add_argument("--model_type", type=str, default="sensevoice", help="model type, e.g. sensevoice")
parser.add_argument("--use_gpu", action="store_true", default=True)
parser.add_argument("--warmup", action="store_true", help="whether do warmup when first initializing model")
parser.add_argument("--port", type=int, default=8000, help="service port")
args = parser.parse_args()
# 将参数加到 app.state 中
app.state.config = {
"model_dir": args.model_dir,
"model_type": args.model_type,
"use_gpu": args.use_gpu, # True
"warmup": args.warmup
}
uvicorn.run("fastapi_funasr:app",
host="0.0.0.0",
port=args.port,
workers=1
)

Binary file not shown.

View File

@@ -0,0 +1,762 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import logging
import numpy as np
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from torch.cuda.amp import autocast
@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
def __init__(
self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.45,
):
super().__init__()
self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
def forward(
self,
hidden,
target_label=None,
mask=None,
ignore_id=-1,
mask_chunk_predictor=None,
target_label_length=None,
):
with autocast(False):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
token_num = alphas.sum(-1)
if target_length is not None:
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(
hidden, alphas, token_num, mask=mask
)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
def gen_frame_alignments(
self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
is_training = self.training
if is_training:
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
else:
token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(
index_div_bool_zeros_count, 0, encoder_sequence_length.max()
)
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
1, 1, maximum_length
)
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
predictor_mask = (
(~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
.type(int_type)
.to(encoder_sequence_length.device)
)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(
encoder_sequence_length.dtype
)
return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2")
class CifPredictorV2(torch.nn.Module):
def __init__(
self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.0,
tf2torch_tensor_name_prefix_torch="predictor",
tf2torch_tensor_name_prefix_tf="seq2seq/cif",
tail_mask=True,
):
super().__init__()
self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.tail_mask = tail_mask
def forward(
self,
hidden,
target_label=None,
mask=None,
ignore_id=-1,
mask_chunk_predictor=None,
target_label_length=None,
):
with autocast(False):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length.squeeze(-1)
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
token_num = alphas.sum(-1)
if target_length is not None:
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
if self.tail_mask:
hidden, alphas, token_num = self.tail_process_fn(
hidden, alphas, token_num, mask=mask
)
else:
hidden, alphas, token_num = self.tail_process_fn(
hidden, alphas, token_num, mask=None
)
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
def forward_chunk(self, hidden, cache=None, **kwargs):
is_final = kwargs.get("is_final", False)
batch_size, len_time, hidden_size = hidden.shape
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
alphas = alphas.squeeze(-1)
token_length = []
list_fires = []
list_frames = []
cache_alphas = []
cache_hiddens = []
if cache is not None and "chunk_size" in cache:
alphas[:, : cache["chunk_size"][0]] = 0.0
if not is_final:
alphas[:, sum(cache["chunk_size"][:2]) :] = 0.0
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
if cache is not None and is_final:
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
hidden = torch.cat((hidden, tail_hidden), dim=1)
alphas = torch.cat((alphas, tail_alphas), dim=1)
len_time = alphas.shape[1]
for b in range(batch_size):
integrate = 0.0
frames = torch.zeros((hidden_size), device=hidden.device)
list_frame = []
list_fire = []
for t in range(len_time):
alpha = alphas[b][t]
if alpha + integrate < self.threshold:
integrate += alpha
list_fire.append(integrate)
frames += alpha * hidden[b][t]
else:
frames += (self.threshold - integrate) * hidden[b][t]
list_frame.append(frames)
integrate += alpha
list_fire.append(integrate)
integrate -= self.threshold
frames = integrate * hidden[b][t]
cache_alphas.append(integrate)
if integrate > 0.0:
cache_hiddens.append(frames / integrate)
else:
cache_hiddens.append(frames)
token_length.append(torch.tensor(len(list_frame), device=alphas.device))
list_fires.append(list_fire)
list_frames.append(list_frame)
cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
max_token_len = max(token_length)
if max_token_len == 0:
return hidden, torch.stack(token_length, 0), None, None
list_ls = []
for b in range(batch_size):
pad_frames = torch.zeros(
(max_token_len - token_length[b], hidden_size), device=alphas.device
)
if token_length[b] == 0:
list_ls.append(pad_frames)
else:
list_frames[b] = torch.stack(list_frames[b])
list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))
cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
if b > 1:
alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
else:
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
def gen_frame_alignments(
self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
is_training = self.training
if is_training:
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
else:
token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(
index_div_bool_zeros_count, 0, encoder_sequence_length.max()
)
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
1, 1, maximum_length
)
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
predictor_mask = (
(~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
.type(int_type)
.to(encoder_sequence_length.device)
)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(
encoder_sequence_length.dtype
)
return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2Export")
class CifPredictorV2Export(torch.nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
self.pad = model.pad
self.cif_conv1d = model.cif_conv1d
self.cif_output = model.cif_output
self.threshold = model.threshold
self.smooth_factor = model.smooth_factor
self.noise_threshold = model.noise_threshold
self.tail_threshold = model.tail_threshold
def forward(
self,
hidden: torch.Tensor,
mask: torch.Tensor,
):
alphas, token_num = self.forward_cnn(hidden, mask)
mask = mask.transpose(-1, -2).float()
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
def forward_cnn(
self,
hidden: torch.Tensor,
mask: torch.Tensor,
):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
return alphas, token_num
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
@torch.jit.script
def cif_v1_export(hidden, alphas, threshold: float):
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
# prefix_sum = torch.cumsum(alphas, dim=1)
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
torch.float32
) # cumsum precision degradation cause wrong result in extreme
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
dislocation_prefix_sum_floor[:, 0] = 0
dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
fire_idxs = dislocation_diff > 0
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
batch_len = fire_idxs.sum(1)
batch_idxs = torch.cumsum(batch_len, dim=0)
shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
shift_batch_idxs[0] = 0
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
frames = frames - shift_frames + shift_remain_frames - remain_frames
# max_label_len = batch_len.max()
max_label_len = alphas.sum(dim=-1)
max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
return frame_fires, fires
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
batch_size, len_time, hidden_size = hidden.size()
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
# loop varss
integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = (
torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
)
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
integrate,
)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(
fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
fire_idxs = fires >= threshold
frame_fires = torch.zeros_like(hidden)
max_label_len = frames[0, fire_idxs[0]].size(0)
for b in range(batch_size):
frame_fire = frames[b, fire_idxs[b]]
frame_len = frame_fire.size(0)
frame_fires[b, :frame_len, :] = frame_fire
if frame_len >= max_label_len:
max_label_len = frame_len
frame_fires = frame_fires[:, :max_label_len, :]
return frame_fires, fires
class mae_loss(torch.nn.Module):
def __init__(self, normalize_length=False):
super(mae_loss, self).__init__()
self.normalize_length = normalize_length
self.criterion = torch.nn.L1Loss(reduction="sum")
def forward(self, token_length, pre_token_length):
loss_token_normalizer = token_length.size(0)
if self.normalize_length:
loss_token_normalizer = token_length.sum().type(torch.float32)
loss = self.criterion(token_length, pre_token_length)
loss = loss / loss_token_normalizer
return loss
def cif(hidden, alphas, threshold):
batch_size, len_time, hidden_size = hidden.size()
# loop varss
integrate = torch.zeros([batch_size], device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place, integrate - torch.ones([batch_size], device=hidden.device), integrate
)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(
fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
list_ls = []
len_labels = torch.round(alphas.sum(-1)).int()
max_label_len = len_labels.max()
for b in range(batch_size):
fire = fires[b, :]
l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
list_ls.append(torch.cat([l, pad_l], 0))
return torch.stack(list_ls, 0), fires
def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
batch_size, len_time = alphas.size()
device = alphas.device
dtype = alphas.dtype
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
if torch.cuda.get_device_name() == "Iluvatar BI-V100":
# the normal branch causes wrong result in bi-100, and leads to exception in later stages
prefix_sum = torch.cumsum(alphas, dim=1)
else:
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
torch.float32
) # cumsum precision degradation cause wrong result in extreme
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
dislocation_prefix_sum_floor[:, 0] = 0
dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
fire_idxs = dislocation_diff > 0
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor
if return_fire_idxs:
return fires, fire_idxs
return fires
def cif_v1(hidden, alphas, threshold):
fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
# frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
batch_len = fire_idxs.sum(1)
batch_idxs = torch.cumsum(batch_len, dim=0)
shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
shift_batch_idxs[0] = 0
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
frames = frames - shift_frames + shift_remain_frames - remain_frames
# max_label_len = batch_len.max()
max_label_len = (
torch.round(alphas.sum(-1)).int().max()
) # torch.round to calculate the max length
# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
return frame_fires, fires
def cif_wo_hidden(alphas, threshold):
batch_size, len_time = alphas.size()
# loop varss
integrate = torch.zeros([batch_size], device=alphas.device)
# intermediate vars along time
list_fires = []
for t in range(len_time):
alpha = alphas[:, t]
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=alphas.device) * threshold,
integrate,
)
fires = torch.stack(list_fires, 1)
return fires

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,746 @@
import logging
import os
import random
import re
import string
import time
import traceback
from typing import Union
import torch
import torch.nn as nn
from funasr.metrics.compute_acc import compute_accuracy
from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable, to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
from transformers import AutoConfig, AutoModelForCausalLM
from funasr.models.fun_asr_nano.ctc import CTC
from funasr.models.fun_asr_nano.tools.utils import forced_align
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@tables.register("model_classes", "FunASRNano")
class FunASRNano(nn.Module):
def __init__(
self,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
length_normalized_loss: bool = False,
**kwargs,
):
super().__init__()
# audio encoder
hub = audio_encoder_conf.get("hub", None)
self.audio_encoder_activation_checkpoint = audio_encoder_conf.get(
"activation_checkpoint", False
)
if hub == "ms":
from funasr import AutoModel
model = AutoModel(model=audio_encoder, model_revision="master")
audio_encoder_output_size = (
model.model.encoder_output_size
if hasattr(model.model, "encoder_output_size")
else -1
)
audio_encoder = (
model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
)
else:
encoder_class = tables.encoder_classes.get(audio_encoder)
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
if freeze:
for _, param in audio_encoder.named_parameters():
param.requires_grad = False
audio_encoder.eval()
self.audio_encoder = audio_encoder
# llm
self.llm = None
init_param_path = llm_conf.get("init_param_path", None)
llm_dim = None
llm_load_kwargs = llm_conf.get("load_kwargs", {})
config = AutoConfig.from_pretrained(init_param_path)
model = AutoModelForCausalLM.from_config(config, **llm_load_kwargs)
freeze = llm_conf.get("freeze", True)
if freeze:
for _, param in model.named_parameters():
param.requires_grad = False
model.eval()
if llm_conf.get("activation_checkpoint", False):
model.gradient_checkpointing_enable()
self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
self.llm = model.to(dtype_map[self.llm_dtype])
llm_dim = model.get_input_embeddings().weight.shape[-1]
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
if audio_encoder_output_size > 0:
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
audio_adaptor_conf["llm_dim"] = (
llm_dim if llm_dim is not None else audio_adaptor_conf["llm_dim"]
)
audio_adaptor = adaptor_class(**audio_adaptor_conf)
freeze = audio_adaptor_conf.get("freeze", False)
if freeze:
for _, param in audio_adaptor.named_parameters():
param.requires_grad = False
audio_adaptor.eval()
self.audio_adaptor = audio_adaptor
self.use_low_frame_rate = audio_adaptor_conf.get("use_low_frame_rate", False)
# ctc decoder
self.ctc_decoder = None
# TODO: fix table name
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
if ctc_decoder_class is not None:
ctc_tokenizer = (
kwargs.get("ctc_tokenizer", None)
if "ctc_tokenizer" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer"]
)
ctc_tokenizer_conf = (
kwargs.get("ctc_tokenizer_conf", None)
if "ctc_tokenizer_conf" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
)
if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
self.ctc_tokenizer = ctc_tokenizer
assert ctc_tokenizer is not None, f"ctc_tokenizer must be set"
ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
if audio_encoder_output_size > 0:
ctc_decoder_conf["encoder_dim"] = audio_encoder_output_size
self.ctc_decoder = ctc_decoder_class(**ctc_decoder_conf)
init_param_path = ctc_decoder_conf.get("init_param_path", None)
if init_param_path is not None:
src_state = torch.load(init_param_path, map_location="cpu")
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
freeze = ctc_decoder_conf.get("freeze", False)
if freeze:
for _, param in self.ctc_decoder.named_parameters():
param.requires_grad = False
self.ctc_decoder.eval()
ctc_conf = kwargs.get("ctc_conf", {})
self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
self.ctc_weight = kwargs.get("ctc_weight", 0.3)
self.ctc = CTC(
odim=ctc_vocab_size,
encoder_output_size=audio_encoder_output_size,
blank_id=self.blank_id,
**ctc_conf,
)
self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
self.error_calculator = None
self.length_normalized_loss = length_normalized_loss
rank = int(os.environ.get("RANK", 0))
logging.info(f"rank: {rank}, model is builded.")
def forward(
self,
speech: torch.Tensor = None,
speech_lengths: torch.Tensor = None,
input_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
labels_ids: torch.Tensor = None,
fbank_beg: torch.Tensor = None,
fbank_mask: torch.Tensor = None,
**kwargs,
):
batch_size, token_num = input_ids.shape
stats = {}
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
if speech is not None:
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size_speech, frames, _ = speech.shape
# audio encoder
if self.audio_encoder_activation_checkpoint:
from torch.utils.checkpoint import checkpoint
encoder_out, encoder_out_lens = checkpoint(
self.encode, speech, speech_lengths, use_reentrant=False
)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len")
fake_token_len[fake_token_len < 0] = 0
fbank_beg[fbank_beg < 0] = 0
speech_idx = 0
for batch_idx in range(batch_size):
for turn_id in range(fbank_beg.shape[1]):
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
if fbank_beg_idx > 0:
speech_token_len = fake_token_len[batch_idx, turn_id]
speech_token = encoder_out[speech_idx, :speech_token_len, :]
try:
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
except Exception as e:
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
)
speech_token_len = encoder_out_lens[speech_idx].item()
speech_token = encoder_out[speech_idx, :speech_token_len, :]
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
speech_idx += 1
stats["batch_size_speech"] = batch_size_speech
stats["batch_size_x_frames"] = frames * batch_size_speech
stats["batch_size_real_frames"] = speech_lengths.sum().item()
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
device_type = next(self.parameters()).device.type
with torch.autocast(
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
enabled=True if self.llm_dtype != "fp32" else False,
dtype=dtype_map[self.llm_dtype],
):
labels_ids[labels_ids == -1] = -100
attention_mask[attention_mask < 0] = 0
model_outputs = self.llm(
inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]),
attention_mask=attention_mask,
labels=labels_ids,
)
loss = model_outputs.loss
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
stats["batch_size"] = batch_size
stats["batch_size_x_tokens"] = token_num * batch_size
stats["batch_size_real_tokens"] = attention_mask.sum().item()
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
dialog_turns = (fbank_beg > 0).sum(-1)
dialog_turns_max = torch.max(dialog_turns).int().item()
dialog_turns_avg = dialog_turns.sum().item() / batch_size
stats["dialog_turns_max"] = dialog_turns_max
stats["dialog_turns_avg"] = dialog_turns_avg
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((labels_ids > 0 + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def forward_export(self, speech, speech_lengths, **kwargs):
x, olens = self.audio_encoder(speech, speech_lengths)
encoder_out, encoder_out_lens = self.audio_adaptor(x, olens)
return encoder_out, encoder_out_lens
def encode(self, speech, speech_lengths):
# audio encoder
encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
return encoder_out, encoder_out_lens
def data_template(self, data):
system, user, assistant = [], [], []
for i, item in enumerate(data):
role = item["role"]
content = item["content"]
if role == "system":
system.append(content)
elif role == "user":
if "audio" in item:
audio = item["audio"]
content = [content, audio]
user.append(content)
elif role == "assistant":
assistant.append(content)
system = system * len(user)
contents = {
"system": system,
"user": user,
"assistant": assistant,
}
return contents
def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
do_think = True
sys_prompt = True
if "dataset_conf" in kwargs:
do_think = kwargs["dataset_conf"].get("do_think", True)
sys_prompt = kwargs["dataset_conf"].get("sys_prompt", True)
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
[],
[],
[],
[],
[],
[],
[],
)
input_source_ids = []
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
if i >= kwargs.get("multiturn_num_max", 5):
break
if len(input_ids) > kwargs.get("max_token_length", 1500):
break
if isinstance(user_prompt, (list, tuple)):
user_prompt, audio = user_prompt
if i == 0:
if kwargs.get("infer_with_assistant_input", False):
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}"
if not sys_prompt:
source_input = f"<|im_start|>user\n{user_prompt}"
else:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
if not sys_prompt:
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
)
else:
if kwargs.get("infer_with_assistant_input", False):
source_input = f"<|im_start|>user\n{user_prompt}"
else:
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
)
if not do_think:
source_input += "<think>\n\n</think>\n\n"
if kwargs.get("prev_text", None) is not None:
source_input += kwargs["prev_text"]
splits = pattern.split(source_input)
source_ids = []
fbank_mask_i = []
fake_token_len_i = 0
fbank_beg_i = -1
speech, speech_lengths = [], []
for k, sub_str in enumerate(splits):
if not sub_str.startswith("<|startofspeech|>"):
sub_token = tokenizer.encode(sub_str)
source_ids += sub_token
fbank_mask_i += [0] * len(sub_token)
else:
sub_str = sub_str.replace("<|startofspeech|>", "").replace(
"<|endofspeech|>", ""
)
if sub_str.startswith("!"):
sub_str = sub_str[1:]
if sub_str.startswith("!"): # !!: audio sample point
sub_str = audio
try:
time1 = time.perf_counter()
data_src = load_audio_text_image_video(
sub_str, fs=frontend.fs, **kwargs
)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
except Exception as e:
logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
speech, speech_lengths = extract_fbank(
data_src,
data_type=kwargs.get("data_type", "sound"),
frontend=frontend,
is_final=True,
) # speech: [b, T, d]
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = (
speech_lengths.sum().item()
* frontend.frame_shift
* frontend.lfr_n
/ 1000
)
if self.use_low_frame_rate:
olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
fake_token_len_i = (olens - 1) // 2 + 1
else:
fake_token_len_i = speech_lengths[0].item()
fake_token = [0] * fake_token_len_i
fbank_beg_i = len(source_ids)
source_ids += fake_token
fbank_mask_i += [1] * len(fake_token)
fbank_beg += [fbank_beg_i + len(input_ids)]
fake_token_len += [fake_token_len_i]
source_mask = [-100] * len(source_ids)
target_out = f"{target_out}<|im_end|>"
target_ids = tokenizer.encode(target_out)
input_source_ids = input_ids + source_ids
input_ids += source_ids + target_ids
labels += source_mask + target_ids
fbank_mask += fbank_mask_i
if len(speech) > 0:
fbank.append(speech[0, :, :])
fbank_lens.append(speech_lengths)
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
target_ids = torch.tensor(target_ids, dtype=torch.int64)
if len(fbank) > 0:
speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
speech_lengths = torch.nn.utils.rnn.pad_sequence(
fbank_lens, batch_first=True, padding_value=-1
)
else:
speech = []
speech_lengths = []
output = {
"speech": speech,
"speech_lengths": speech_lengths,
"fbank_mask": fbank_mask[None, :],
"fbank_beg": fbank_beg[None,],
"fake_token_len": fake_token_len[None, :],
"input_ids": input_ids[None,],
"attention_mask": attention_mask[None,],
"labels_ids": labels,
"source_ids": source_ids[None, :],
"target_ids": target_ids[None, :],
}
return output
def inference_prepare(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
meta_data = {}
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
# audio encoder
speech = batch["speech"]
if len(speech) > 0:
if "audio_embedding" in kwargs and "audio_embedding_lens" in kwargs:
encoder_out = kwargs["audio_embedding"]
encoder_out_lens = kwargs["audio_embedding_lens"]
else:
speech_lengths = batch["speech_lengths"][:, 0]
# fp16
if kwargs.get("fp16", False):
speech = speech.to(torch.float16)
elif kwargs.get("bf16", False):
speech = speech.to(torch.bfloat16)
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
meta_data["encoder_out"] = encoder_out
meta_data["encoder_out_lens"] = encoder_out_lens
meta_data["audio_adaptor_out"] = adaptor_out
meta_data["audio_adaptor_out_lens"] = adaptor_out_lens
input_ids = batch["input_ids"]
source_ids = batch["source_ids"]
fbank_beg = batch["fbank_beg"]
fake_token_len = batch["fake_token_len"]
if not kwargs.get("teacherforcing", False):
input_ids = source_ids
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len[fake_token_len < 0] = 0
fbank_beg[fbank_beg < 0] = 0
speech_idx = 0
for batch_idx in range(batch_size):
for turn_id in range(fbank_beg.shape[1]):
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
if fbank_beg_idx > 0:
speech_token_len = fake_token_len[batch_idx, turn_id]
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
try:
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
except Exception as e:
#
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
)
speech_token_len = adaptor_out_lens[speech_idx].item()
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
speech_idx += 1
return inputs_embeds, contents, batch, source_ids, meta_data
def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
if len(hotwords) > 0:
hotwords = ", ".join(hotwords)
prompt = f"请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
prompt += f"热词列表:[{hotwords}]\n"
else:
prompt = ""
if language is None:
prompt += "语音转写"
else:
prompt += f"语音转写成{language}"
if not itn:
prompt += ",不进行文本规整"
return prompt + ""
def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
if isinstance(data, str):
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"},
{"role": "assistant", "content": "null"},
]
elif isinstance(data, torch.Tensor):
return [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
"audio": data,
},
{"role": "assistant", "content": "null"},
]
def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
prompt = self.get_prompt(
kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
)
data_in = [self.generate_chatml(prompt, data) for data in data_in]
if key is None:
key = []
for _ in data_in:
chars = string.ascii_letters + string.digits
key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))
return self.inference_llm(
data_in,
data_lengths=data_lengths,
key=key,
tokenizer=tokenizer,
frontend=frontend,
**kwargs,
)
def inference_llm(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
ctc_results = []
if self.ctc_decoder is not None:
encoder_out = meta_data["encoder_out"]
encoder_out_lens = meta_data["encoder_out_lens"]
decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
ctc_logits = self.ctc.log_softmax(decoder_out)
b, n, d = encoder_out.size()
if isinstance(key[0], (list, tuple)):
key = key[0]
if len(key) < b:
key = key * b
for i in range(b):
x = ctc_logits[i, : encoder_out_lens[i].item(), :]
yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1)
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()
# Change integer-ids to tokens
text = self.ctc_tokenizer.decode(token_int)
ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
device_type = torch.device(kwargs.get("device", "cuda")).type
with torch.autocast(
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
enabled=True if llm_dtype != "fp32" else False,
dtype=dtype_map[llm_dtype],
):
label = contents["assistant"][-1]
self.llm = self.llm.to(dtype_map[llm_dtype])
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
llm_kwargs = kwargs.get("llm_kwargs", {})
if not kwargs.get("teacherforcing", False):
attention_mask = batch.get("attention_mask", None)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=kwargs.get("max_length", 512),
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
**llm_kwargs,
)
response = tokenizer.batch_decode(
generated_ids,
skip_special_tokens=kwargs.get("skip_special_tokens", True),
)[0]
loss = None
else:
labels_ids = batch["labels_ids"]
labels_ids[labels_ids == -1] = -100
attention_mask = batch.get("attention_mask", None)
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels_ids,
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
**llm_kwargs,
)
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
response = tokenizer.batch_decode(
preds,
add_special_tokens=False,
skip_special_tokens=kwargs.get("skip_special_tokens", True),
)[0]
loss = model_outputs.loss.item()
response = kwargs.get("prev_text", "") + response
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{0 + 1}best_recog"]
results = []
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
result_i = {
"key": key[0],
"text": re.sub(r"\s+", " ", response.replace("/sil", " ")),
"text_tn": response_clean,
"label": label,
}
if loss is not None:
result_i["loss"] = loss
results.append(result_i)
for ctc_result, result in zip(ctc_results, results):
result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
target_ids = torch.tensor(
self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
)
result["ctc_timestamps"] = forced_align(
ctc_result["ctc_logits"], target_ids, self.blank_id
)
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
for timestamp in timestamps:
timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
if ibest_writer is not None:
ibest_writer["text"][key[0]] = response.replace("\n", " ")
ibest_writer["label"][key[0]] = label.replace("\n", " ")
ibest_writer["text_tn"][key[0]] = response_clean
return results, meta_data
@staticmethod
def from_pretrained(model: str = None, **kwargs):
from funasr import AutoModel
model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
return model, kwargs

View File

@@ -0,0 +1,546 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class mae_loss(torch.nn.Module):
def __init__(self, normalize_length=False):
super(mae_loss, self).__init__()
self.normalize_length = normalize_length
self.criterion = torch.nn.L1Loss(reduction="sum")
def forward(self, token_length, pre_token_length):
loss_token_normalizer = token_length.size(0)
if self.normalize_length:
loss_token_normalizer = token_length.sum().type(torch.float32)
loss = self.criterion(token_length, pre_token_length)
loss = loss / loss_token_normalizer
return loss
def cif(hidden, alphas, threshold):
batch_size, len_time, hidden_size = hidden.size()
# loop varss
integrate = torch.zeros([batch_size], device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place, integrate - torch.ones([batch_size], device=hidden.device), integrate
)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(
fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
list_ls = []
len_labels = torch.round(alphas.sum(-1)).int()
max_label_len = len_labels.max()
for b in range(batch_size):
fire = fires[b, :]
l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
list_ls.append(torch.cat([l, pad_l], 0))
return torch.stack(list_ls, 0), fires
def cif_wo_hidden(alphas, threshold):
batch_size, len_time = alphas.size()
# loop varss
integrate = torch.zeros([batch_size], device=alphas.device)
# intermediate vars along time
list_fires = []
for t in range(len_time):
alpha = alphas[:, t]
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=alphas.device) * threshold,
integrate,
)
fires = torch.stack(list_fires, 1)
return fires
@tables.register("predictor_classes", "CifPredictorV3")
class CifPredictorV3(torch.nn.Module):
def __init__(
self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.0,
tf2torch_tensor_name_prefix_torch="predictor",
tf2torch_tensor_name_prefix_tf="seq2seq/cif",
smooth_factor2=1.0,
noise_threshold2=0,
upsample_times=5,
upsample_type="cnn",
use_cif1_cnn=True,
tail_mask=True,
):
super(CifPredictorV3, self).__init__()
self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.upsample_times = upsample_times
self.upsample_type = upsample_type
self.use_cif1_cnn = use_cif1_cnn
if torch.cuda.get_device_name() == 'R200-8F' and self.upsample_type != "cnn":
# kunlunxin doesn't support some ops in other two branches
self.upsample_type = "cnn"
if self.upsample_type == "cnn":
self.upsample_cnn = torch.nn.ConvTranspose1d(
idim, idim, self.upsample_times, self.upsample_times
)
self.cif_output2 = torch.nn.Linear(idim, 1)
elif self.upsample_type == "cnn_blstm":
self.upsample_cnn = torch.nn.ConvTranspose1d(
idim, idim, self.upsample_times, self.upsample_times
)
self.blstm = torch.nn.LSTM(
idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True
)
self.cif_output2 = torch.nn.Linear(idim * 2, 1)
elif self.upsample_type == "cnn_attn":
self.upsample_cnn = torch.nn.ConvTranspose1d(
idim, idim, self.upsample_times, self.upsample_times
)
from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
positionwise_layer_args = (
idim,
idim * 2,
0.1,
)
self.self_attn = TransformerEncoderLayer(
idim,
MultiHeadedAttention(4, idim, 0.1),
PositionwiseFeedForward(*positionwise_layer_args),
0.1,
True, # normalize_before,
False, # concat_after,
)
self.cif_output2 = torch.nn.Linear(idim, 1)
self.smooth_factor2 = smooth_factor2
self.noise_threshold2 = noise_threshold2
def forward(
self,
hidden,
target_label=None,
mask=None,
ignore_id=-1,
mask_chunk_predictor=None,
target_label_length=None,
):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
# alphas2 is an extra head for timestamp prediction
if not self.use_cif1_cnn:
_output = context
else:
_output = output
if self.upsample_type == "cnn":
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
elif self.upsample_type == "cnn_blstm":
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, (_, _) = self.blstm(output2)
elif self.upsample_type == "cnn_attn":
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, _ = self.self_attn(output2, mask)
alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
# repeat the mask in T demension to match the upsampled length
if mask is not None:
mask2 = (
mask.repeat(1, self.upsample_times, 1)
.transpose(-1, -2)
.reshape(alphas2.shape[0], -1)
)
mask2 = mask2.unsqueeze(-1)
alphas2 = alphas2 * mask2
alphas2 = alphas2.squeeze(-1)
token_num2 = alphas2.sum(-1)
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
token_num = alphas.sum(-1)
if target_length is not None:
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak, token_num2
def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
# alphas2 is an extra head for timestamp prediction
if not self.use_cif1_cnn:
_output = context
else:
_output = output
if self.upsample_type == "cnn":
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
elif self.upsample_type == "cnn_blstm":
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, (_, _) = self.blstm(output2)
elif self.upsample_type == "cnn_attn":
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, _ = self.self_attn(output2, mask)
alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
# repeat the mask in T demension to match the upsampled length
if mask is not None:
mask2 = (
mask.repeat(1, self.upsample_times, 1)
.transpose(-1, -2)
.reshape(alphas2.shape[0], -1)
)
mask2 = mask2.unsqueeze(-1)
alphas2 = alphas2 * mask2
alphas2 = alphas2.squeeze(-1)
_token_num = alphas2.sum(-1)
if token_num is not None:
alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
# re-downsample
ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
# upsampled alphas and cif_peak
us_alphas = alphas2
us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
def gen_frame_alignments(
self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
is_training = self.training
if is_training:
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
else:
token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(
index_div_bool_zeros_count, 0, encoder_sequence_length.max()
)
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
1, 1, maximum_length
)
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
predictor_mask = (
(~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
.type(int_type)
.to(encoder_sequence_length.device)
)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(
encoder_sequence_length.dtype
)
return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV3Export")
class CifPredictorV3Export(torch.nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
self.pad = model.pad
self.cif_conv1d = model.cif_conv1d
self.cif_output = model.cif_output
self.threshold = model.threshold
self.smooth_factor = model.smooth_factor
self.noise_threshold = model.noise_threshold
self.tail_threshold = model.tail_threshold
self.upsample_times = model.upsample_times
self.upsample_cnn = model.upsample_cnn
self.blstm = model.blstm
self.cif_output2 = model.cif_output2
self.smooth_factor2 = model.smooth_factor2
self.noise_threshold2 = model.noise_threshold2
def forward(
self,
hidden: torch.Tensor,
mask: torch.Tensor,
):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
# generate alphas2
_output = context
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, (_, _) = self.blstm(output2)
alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
mask = (
mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
)
mask = mask.unsqueeze(-1)
alphas2 = alphas2 * mask
alphas2 = alphas2.squeeze(-1)
_token_num = alphas2.sum(-1)
alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
# upsampled alphas and cif_peak
us_alphas = alphas2
us_cif_peak = cif_wo_hidden_export(us_alphas, self.threshold - 1e-4)
return us_alphas, us_cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
batch_size, len_time, hidden_size = hidden.size()
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
# loop varss
integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = (
torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
)
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
integrate,
)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(
fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
fire_idxs = fires >= threshold
frame_fires = torch.zeros_like(hidden)
max_label_len = frames[0, fire_idxs[0]].size(0)
for b in range(batch_size):
frame_fire = frames[b, fire_idxs[b]]
frame_len = frame_fire.size(0)
frame_fires[b, :frame_len, :] = frame_fire
if frame_len >= max_label_len:
max_label_len = frame_len
frame_fires = frame_fires[:, :max_label_len, :]
return frame_fires, fires
@torch.jit.script
def cif_wo_hidden_export(alphas, threshold: float):
batch_size, len_time = alphas.size()
# loop varss
integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
# intermediate vars along time
list_fires = []
for t in range(len_time):
alpha = alphas[:, t]
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=alphas.device) * threshold,
integrate,
)
fires = torch.stack(list_fires, 1)
return fires

View File

@@ -9,3 +9,10 @@ ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
transformers>=4.51.3
funasr>=1.3.0
zhconv
whisper_normalizer
pyopenjtalk-plus
compute-wer
openai-whisper