initial commit

This commit is contained in:
2026-04-08 06:41:00 +00:00
commit 1385a6f46b
23 changed files with 2831 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
FROM corex:4.3.8
WORKDIR /root

# Point apt at the official Ubuntu archive instead of the Aliyun mirror (the
# mirror answers 403), install the OS packages, and drop the apt lists in the
# same layer so they never end up in the image.
RUN set -eux; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list.d/*.list 2>/dev/null || true; \
    apt-get update; \
    apt-get install -y --no-install-recommends \
        ca-certificates \
        net-tools \
        vim; \
    rm -rf /var/lib/apt/lists/*

# Install Python dependencies BEFORE copying the application source, so that a
# source-only change reuses the cached dependency layers.
COPY requirements.txt /root/
RUN pip install --no-cache-dir -r /root/requirements.txt \
        -i https://nexus.4pd.io/repository/pypi-all/simple \
        --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple
RUN pip install --no-cache-dir funasr==1.3.1 'transformers>=4.51.3' openai-whisper \
        -i https://nexus.4pd.io/repository/pypi-all/simple \
        --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple

# Patch files: overwrite the installed FunASR modules with the local versions.
COPY ./replaced_files/bi_v150/cif_predictor.py /usr/local/lib/python3.10/site-packages/funasr/models/paraformer/
COPY ./replaced_files/funasr_nano_model.py /usr/local/lib/python3.10/site-packages/funasr/models/fun_asr_nano/model.py

# Application source. Plain COPY is enough: nothing here needs ADD's
# archive-extraction or URL handling.
COPY . /root/

ENTRYPOINT ["python3"]
CMD ["main.py"]

28
funasr/README.md Normal file
View File

@@ -0,0 +1,28 @@
# 天数智芯 天垓150 ASR(FunASR架构)
## 镜像构造
```shell
docker build -f ./Dockerfile.funasr-bi150 -t <your_image> .
```
其中,基础镜像 corex:4.3.8 通过联系天数智芯智铠100厂商技术支持可获取
## 使用说明
### 使用 FastAPI 启动ASR服务
例如:
```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
-v /mnt/contest_ceph/leaderboard/modelHubXC/iic/SenseVoiceSmall:/model \
--network=host <your_image> \
main.py --model_dir /model --model_type sensevoice --use_gpu --port 1111
```
具体的参数设定可参考代码文件 `main.py`
### 测试ASR服务
项目根路径`sample_data`目录下附带上了中文的测试音频和附带内容
```shell
curl -X POST http://localhost:1111/transduce \
-F "audio=@../sample_data/lei-jun-test.wav" \
-F "lang=zh"
```

271
funasr/fastapi_funasr.py Normal file
View File

@@ -0,0 +1,271 @@
import os
import time
import argparse
import torchaudio
import torch
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uuid
import uvicorn
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# from funasr.models.fun_asr_nano.model import FunASRNano
# Service-wide state shared by the startup hook and the request handlers.
app = FastAPI()
model = None
device = ""
status = "Running"

# Directory where uploaded audio is written before inference.
os.makedirs("./input", exist_ok=True)

# CUSTOM_DEVICE selects an extra torch backend module by prefix; the import is
# done only for its side effects (the imported names are not used in this file).
CUSTOM_DEVICE = os.environ.get("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu
def make_all_dense(module: torch.nn.Module):
    """Replace every non-strided (sparse) parameter and buffer of ``module`` by a dense copy.

    The module tree is modified in place. Parameters keep their ``requires_grad``
    flag; buffers are re-registered as persistent buffers.

    Both parameters and buffers are detected through ``layout != torch.strided``,
    so sparse layouts other than COO are converted as well (``Tensor.is_sparse``
    only reports the COO layout).

    Args:
        module: root module whose parameters and buffers (including those of all
            submodules) are converted.
    """

    def _owner_of(qualified_name):
        # Walk "a.b.c" down to the submodule that directly owns attribute "c".
        *path, leaf = qualified_name.split(".")
        owner = module
        for part in path:
            owner = getattr(owner, part)
        return owner, leaf

    # list(...) snapshots the iterator: attributes are replaced while we loop.
    for name, param in list(module.named_parameters(recurse=True)):
        if param.layout != torch.strided:
            with torch.no_grad():
                dense = param.to_dense().contiguous()
            owner, leaf = _owner_of(name)
            setattr(owner, leaf, torch.nn.Parameter(dense, requires_grad=param.requires_grad))

    # Buffers (e.g. running_mean) get the same treatment.
    for name, buf in list(module.named_buffers(recurse=True)):
        if buf.layout != torch.strided:
            dense = buf.to_dense().contiguous()
            owner, leaf = _owner_of(name)
            owner.register_buffer(leaf, dense, persistent=True)
def split_audio(waveform, sample_rate, segment_seconds=20):
    """Cut ``waveform`` into consecutive chunks along dimension 1.

    Args:
        waveform: tensor that is sliced as ``waveform[:, start:stop]``.
        sample_rate: samples per second; with ``segment_seconds`` it gives the chunk length.
        segment_seconds: length of every chunk in seconds (the last chunk may be shorter).

    Returns:
        List of non-empty chunks in their original order.
    """
    chunk_len = segment_seconds * sample_rate
    total_samples = waveform.shape[1]
    chunks = (
        waveform[:, start:start + chunk_len]
        for start in range(0, total_samples, chunk_len)
    )
    return [chunk for chunk in chunks if chunk.shape[1] > 0]
# def determine_model_type(model_name):
# if "sensevoice" in model_name.lower():
# return "sensevoice"
# elif "whisper" in model_name.lower():
# return "whisper"
# elif "paraformer" in model_name.lower():
# return "paraformer"
# elif "conformer" in model_name.lower():
# return "conformer"
# elif "uniasr" in model_name.lower():
# return "uni_asr"
# else:
# return "unknown"
@app.on_event("startup")
def load_model():
    """Load the FunASR model once when the application starts.

    Reads the launch options from ``app.state.config`` (keys ``use_gpu``,
    ``model_dir``, ``model_type``, ``warmup``; defaults apply when the config or
    a key is missing), chooses the target device from ``CUSTOM_DEVICE`` and
    stores the result in the module globals ``model``, ``device`` and ``status``.

    ``status`` is set to ``"Success"`` when loading finishes and to ``"Failed"``
    when anything raises; the exception is then re-raised unchanged.
    """
    global status, model, device
    try:
        # app.state.config is only populated by main.py; fall back to defaults otherwise.
        config = getattr(app.state, "config", {})
        use_gpu = config.get("use_gpu", True)
        model_dir = config.get("model_dir", "/model")
        model_type = config.get("model_type", "sensevoice")
        warmup = config.get("warmup", False)
        print(">> Startup config:")
        print(" model_dir =", model_dir, flush=True)
        print(" model_type =", model_type, flush=True)
        print(" use_gpu =", use_gpu, flush=True)
        print(" warmup =", warmup, flush=True)

        device = "cpu"
        if use_gpu:
            if CUSTOM_DEVICE.startswith("mlu"):
                device = "mlu:0"
            elif CUSTOM_DEVICE.startswith("ascend"):
                device = "npu:0"
            else:
                device = "cuda:0"

        # Accelerator-specific handling.
        if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
            # Whisper on BI-V100 has to bypass unsupported operators:
            # only the math scaled-dot-product kernel is left enabled.
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)
        print(f"device: {device}", flush=True)

        dense_convert = False
        if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
            dense_convert = True
        if device.startswith("npu") and model_type == "whisper":
            # On Ascend NPU, loading whisper hits a device mismatch in its sparse tensors.
            dense_convert = True
        print(f"dense_convert: {dense_convert}", flush=True)

        if dense_convert:
            # Load on CPU, densify all sparse tensors, then move to the target device.
            model = AutoModel(
                model=model_dir,
                vad_model=None,
                disable_update=True,
                device="cpu"
            )
            make_all_dense(model.model)
            model.model.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            model.model.to(device)
            model.kwargs["device"] = device
        else:
            # No VAD / punctuation / speaker models: only the raw ASR capability is tested.
            model = AutoModel(
                model=model_dir,
                # vad_model="fsmn-vad",
                # vad_kwargs={"max_single_segment_time": 30000},
                vad_model=None,
                device=device,
                disable_update=True
            )

        if device.startswith("npu") or warmup:
            # Ascend NPU needs a warmup run first: its device initialisation and
            # scheduling is more involved than on the other cards.
            print("Start warmup...", flush=True)
            model.generate(input="warmup.wav")
            print("warmup complete.", flush=True)
    except Exception:
        # Record the failure so /health reports "failed" instead of "loading model".
        status = "Failed"
        raise
    status = "Success"
def test_funasr(audio_file, lang):
    """Transcribe ``audio_file`` with the globally loaded model and return the text.

    For every model type except ``uni_asr`` the audio is cut into 20-second
    segments; each segment is written to a temporary wav file in the current
    directory, transcribed, and the file is removed again (also when
    transcription raises). Temporary names carry a per-call random id so that
    concurrent calls cannot overwrite each other's segments.

    Args:
        audio_file: path of the audio file to transcribe.
        lang: language code; passed to whisper, and used to strip spaces from
            paraformer output when it is ``"zh"``.

    Returns:
        The concatenated text of all segments.

    Raises:
        RuntimeError: when a segment is processed with an unknown ``model_type``.
    """
    # Inference
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate

    generated_text = ""
    processing_time = 0
    model_type = app.state.config.get("model_type", "sensevoice")
    if model_type == "uni_asr":
        # uni_asr is special: it is designed for long audio and does its own VAD
        # splitting. If the first 20 s contain almost no speech (only music) a
        # pre-cut segment may be trimmed to nothing and the model errors out on
        # an empty input, so the whole file is passed instead.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        segments = split_audio(waveform, sample_rate, segment_seconds=20)
        # Random id per call: requests served in parallel must not share temp files.
        run_id = uuid.uuid4().hex
        # Feed the segments one after another.
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{run_id}_{i}.wav"
            try:
                torchaudio.save(segment_path, segment, sample_rate)
                ts1 = time.time()
                text = None
                if model_type == "sensevoice":
                    res = model.generate(
                        input=segment_path,
                        cache={},
                        language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
                        use_itn=True,
                        batch_size_s=60,
                        merge_vad=False,
                        # merge_length_s=15,
                    )
                    text = rich_transcription_postprocess(res[0]["text"])
                elif model_type == "whisper":
                    DecodingOptions = {
                        "task": "transcribe",
                        "language": lang,
                        "beam_size": None,
                        "fp16": False,
                        "without_timestamps": False,
                        "prompt": None,
                    }
                    res = model.generate(
                        DecodingOptions=DecodingOptions,
                        input=segment_path,
                        batch_size_s=0,
                    )
                    text = res[0]["text"]
                elif model_type == "paraformer":
                    res = model.generate(
                        input=segment_path,
                        batch_size_s=300
                    )
                    text = res[0]["text"]
                    # paraformer emits character by character; the many spaces
                    # in between would distort the 1-CER score.
                    if lang == "zh":
                        text = text.replace(" ", "")
                elif model_type == "conformer":
                    res = model.generate(
                        input=segment_path,
                        batch_size_s=300
                    )
                    text = res[0]["text"]
                else:
                    raise RuntimeError("unknown model type")
                if text is not None:
                    # some models output "▁" (9601, Unicode U+2581) as separator between
                    # words, replace them with space for better readability
                    text = text.replace("_", " ")
                    text = text.replace(chr(9601), " ")
                ts2 = time.time()
                generated_text += text
                processing_time += (ts2 - ts1)
            finally:
                # Always clean up, including when saving or inference raised.
                if os.path.exists(segment_path):
                    os.remove(segment_path)

    # Zero-length audio has no meaningful real-time factor; avoid dividing by zero.
    rtf = processing_time / duration if duration > 0 else float("nan")
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Report the model-loading state held in the module-level ``status``.

    Returns ``{"status": "loading model"}`` while ``status`` is ``"Running"``,
    ``{"status": "ok"}`` once it is ``"Success"`` and ``{"status": "failed"}``
    for any other value.
    """
    if status == "Running":
        return {"status": "loading model"}
    outcome = "ok" if status == "Success" else "failed"
    return {"status": outcome}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    """Transcribe an uploaded audio file.

    The upload is written to ``./input/<uuid>.wav``, handed to ``test_funasr``
    and deleted again before the handler returns, whether transcription
    succeeded or raised.

    Args:
        audio: uploaded audio file (multipart field ``audio``).
        lang: language code forwarded to ``test_funasr`` (form field, default ``"zh"``).
        background_tasks: kept in the signature for compatibility; cleanup no
            longer depends on it, because a task scheduled there would not
            cover the error path that ends in ``HTTPException``.

    Returns:
        ``{"generated_text": <text>}``.

    Raises:
        HTTPException: status 500 with the formatted traceback when anything fails.
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = test_funasr(file_path, lang)
        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
    finally:
        # Remove the upload on success and on failure alike.
        if os.path.exists(file_path):
            os.remove(file_path)
# if __name__ == "__main__":
# uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)

27
funasr/main.py Normal file
View File

@@ -0,0 +1,27 @@
import argparse
import uvicorn
from fastapi_funasr import app
if __name__ == "__main__":
    # Command-line entry point: parse the launch options, publish them on
    # app.state.config for the startup hook, and start uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/model", help="model directory")
    parser.add_argument("--model_type", type=str, default="sensevoice", help="model type, e.g. sensevoice")
    # BooleanOptionalAction keeps `--use_gpu` working and adds `--no-use_gpu`;
    # with store_true and default=True the flag could never be switched off.
    parser.add_argument(
        "--use_gpu",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="run on the accelerator (default); pass --no-use_gpu to run on CPU",
    )
    parser.add_argument("--warmup", action="store_true", help="whether do warmup when first initializing model")
    parser.add_argument("--port", type=int, default=8000, help="service port")
    args = parser.parse_args()

    # Expose the parsed options to the application through app.state.
    app.state.config = {
        "model_dir": args.model_dir,
        "model_type": args.model_type,
        "use_gpu": args.use_gpu,
        "warmup": args.warmup
    }

    uvicorn.run("fastapi_funasr:app",
        host="0.0.0.0",
        port=args.port,
        workers=1
    )

View File

@@ -0,0 +1,762 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import logging
import numpy as np
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from torch.cuda.amp import autocast
@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
    """Predictor that turns a hidden sequence into per-frame weights (``alphas``)
    and integrates them into token-level embeddings with the module-level ``cif``.

    Layers: constant padding of ``(l_order, r_order)``, a depthwise ``Conv1d``
    with kernel size ``l_order + r_order + 1`` (``groups=idim``), dropout, and a
    ``Linear(idim, 1)`` followed by a sigmoid that yields one weight per frame.
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        tail_threshold=0.45,
    ):
        super().__init__()

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold

    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Compute alphas from ``hidden`` and integrate them with ``cif``.

        Args:
            hidden: 3-D tensor (unpacked as ``b, t, d`` in ``tail_process_fn``).
            target_label: optional labels; entries different from ``ignore_id``
                are counted to obtain the target length.
            mask: optional mask multiplied onto the alphas; may be ``None``.
            ignore_id: label value excluded from the target length.
            mask_chunk_predictor: optional extra factor multiplied onto the alphas.
            target_label_length: optional target length; takes precedence over
                ``target_label``.

        Returns:
            ``(acoustic_embeds, token_num, alphas, cif_peak)``.
        """
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            memory = self.cif_conv1d(queries)
            output = memory + context  # residual connection around the convolution
            output = self.dropout(output)
            output = output.transpose(1, 2)
            output = torch.relu(output)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            # mask is optional: squeeze it only when supplied
            # (tail_process_fn accepts mask=None).
            if mask is not None:
                mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                # Rescale the alphas so that they sum to the target length.
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                hidden, alphas, token_num = self.tail_process_fn(
                    hidden, alphas, token_num, mask=mask
                )

            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Append one extra frame carrying ``tail_threshold`` weight.

        With a mask, the extra weight is added at the positions where
        ``[1, mask] - [mask, 0]`` is non-zero; without a mask it is appended as
        a new last column for every batch row. ``hidden`` is extended by one
        zero frame. The incoming ``token_num`` is not used.

        Returns:
            ``(hidden, alphas, floor(alphas.sum(-1)))``.
        """
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            if b > 1:
                # A (1, 1) tail cannot be concatenated to (b, t) along dim 1;
                # repeat it per batch row, as CifPredictorV2 does.
                alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
            else:
                alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor

    def gen_frame_alignments(
        self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
    ):
        """Build an integer per-frame tensor from the floored cumulative sum of ``alphas``.

        The token count per row is the rounded (training) or floored (eval) sum
        of the alphas. The result is masked with ``encoder_sequence_length``.
        NOTE(review): the exact alignment semantics are not evident from this
        file - confirm against upstream FunASR.

        Returns:
            ``(predictor_alignments, predictor_alignments_length)``, both detached.
        """
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(
            index_div_bool_zeros_count, 0, encoder_sequence_length.max()
        )
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
            1, 1, maximum_length
        )
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (
            (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
            .type(int_type)
            .to(encoder_sequence_length.device)
        )
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(
            encoder_sequence_length.dtype
        )
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2")
class CifPredictorV2(torch.nn.Module):
    """Predictor that computes per-frame weights (``alphas``) and integrates them
    into token-level embeddings with the module-level ``cif_v1``.

    Layers: constant padding of ``(l_order, r_order)``, a ``Conv1d`` with kernel
    size ``l_order + r_order + 1``, and a ``Linear(idim, 1)`` followed by a
    sigmoid. ``forward_chunk`` offers a chunk-by-chunk variant that keeps its
    unfinished integration state in a caller-provided ``cache`` dict.
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        tail_threshold=0.0,
        tf2torch_tensor_name_prefix_torch="predictor",
        tf2torch_tensor_name_prefix_tf="seq2seq/cif",
        tail_mask=True,
    ):
        super().__init__()

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tail_mask = tail_mask

    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Compute alphas from ``hidden`` and integrate them with ``cif_v1``.

        Args:
            hidden: 3-D tensor (unpacked as ``b, t, d`` in ``tail_process_fn``).
            target_label: optional labels; entries different from ``ignore_id``
                are counted to obtain the target length.
            mask: optional mask multiplied onto the alphas; may be ``None``.
            ignore_id: label value excluded from the target length.
            mask_chunk_predictor: optional extra factor multiplied onto the alphas.
            target_label_length: optional target length (its last dimension is
                squeezed); takes precedence over ``target_label``.

        Returns:
            ``(acoustic_embeds, token_num, alphas, cif_peak)``.
        """
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            output = torch.relu(self.cif_conv1d(queries))
            output = output.transpose(1, 2)

            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            # mask is optional: squeeze it only when supplied
            # (tail_process_fn accepts mask=None).
            if mask is not None:
                mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length.squeeze(-1)
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                # Rescale the alphas so that they sum to the target length.
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                if self.tail_mask:
                    hidden, alphas, token_num = self.tail_process_fn(
                        hidden, alphas, token_num, mask=mask
                    )
                else:
                    hidden, alphas, token_num = self.tail_process_fn(
                        hidden, alphas, token_num, mask=None
                    )

            acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_chunk(self, hidden, cache=None, **kwargs):
        """Integrate one chunk frame by frame, carrying state across calls in ``cache``.

        ``cache`` keys read: ``"chunk_size"`` (alphas outside the selected range
        are zeroed), ``"cif_alphas"`` / ``"cif_hidden"`` (state from the previous
        chunk, prepended to the current one). Keys written, when ``cache`` is not
        ``None``: ``"cif_alphas"`` and ``"cif_hidden"``.

        Args:
            hidden: 3-D tensor unpacked as ``(batch_size, len_time, hidden_size)``.
            cache: optional dict holding the streaming state.
            **kwargs: ``is_final`` (default ``False``); when true and a cache is
                given, one tail frame with ``tail_threshold`` weight is appended.

        Returns:
            ``(frames, token_length, None, None)``. When no frame fired in any
            batch row, ``hidden`` is returned in place of the frames.
        """
        is_final = kwargs.get("is_final", False)
        batch_size, len_time, hidden_size = hidden.shape
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)

        alphas = alphas.squeeze(-1)

        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []

        if cache is not None and "chunk_size" in cache:
            alphas[:, : cache["chunk_size"][0]] = 0.0
            if not is_final:
                alphas[:, sum(cache["chunk_size"][:2]) :] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
            cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
            hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
            alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
        if cache is not None and is_final:
            tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
            tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
            tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
            hidden = torch.cat((hidden, tail_hidden), dim=1)
            alphas = torch.cat((alphas, tail_alphas), dim=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = torch.zeros((hidden_size), device=hidden.device)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.threshold:
                    # Not enough weight yet: keep accumulating.
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # Fire: finish the current frame with the part of alpha that
                    # reaches the threshold, start the next one with the rest.
                    frames += (self.threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.threshold
                    frames = integrate * hidden[b][t]

            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)

            token_length.append(torch.tensor(len(list_frame), device=alphas.device))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        # Store the unfinished integration state once, before either return;
        # a missing cache simply means there is nowhere to keep it.
        if cache is not None:
            cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
            cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
            cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
            cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)

        max_token_len = max(token_length)
        if max_token_len == 0:
            return hidden, torch.stack(token_length, 0), None, None
        list_ls = []
        for b in range(batch_size):
            # Pad every row with zero frames up to the longest row.
            pad_frames = torch.zeros(
                (max_token_len - token_length[b], hidden_size), device=alphas.device
            )
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_frames[b] = torch.stack(list_frames[b])
                list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))

        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Append one extra frame carrying ``tail_threshold`` weight.

        With a mask, the extra weight is added at the positions where
        ``[1, mask] - [mask, 0]`` is non-zero; without a mask it is appended as
        a new last column for every batch row. ``hidden`` is extended by one
        zero frame. The incoming ``token_num`` is not used.

        Returns:
            ``(hidden, alphas, floor(alphas.sum(-1)))``.
        """
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            if b > 1:
                alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
            else:
                alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor

    def gen_frame_alignments(
        self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
    ):
        """Build an integer per-frame tensor from the floored cumulative sum of ``alphas``.

        The token count per row is the rounded (training) or floored (eval) sum
        of the alphas. The result is masked with ``encoder_sequence_length``.
        NOTE(review): the exact alignment semantics are not evident from this
        file - confirm against upstream FunASR.

        Returns:
            ``(predictor_alignments, predictor_alignments_length)``, both detached.
        """
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(
            index_div_bool_zeros_count, 0, encoder_sequence_length.max()
        )
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
            1, 1, maximum_length
        )
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (
            (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
            .type(int_type)
            .to(encoder_sequence_length.device)
        )
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(
            encoder_sequence_length.dtype
        )
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2Export")
class CifPredictorV2Export(torch.nn.Module):
    """Wrapper that reuses the layers and thresholds of an existing predictor
    and integrates with the scripted ``cif_v1_export``.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        # Take over the wrapped predictor's layers first, then its scalar settings.
        for attr in (
            "pad",
            "cif_conv1d",
            "cif_output",
            "threshold",
            "smooth_factor",
            "noise_threshold",
            "tail_threshold",
        ):
            setattr(self, attr, getattr(model, attr))

    def forward(
        self,
        hidden: torch.Tensor,
        mask: torch.Tensor,
    ):
        """Return ``(acoustic_embeds, token_num, alphas, cif_peak)`` for ``hidden`` under ``mask``."""
        alphas, token_num = self.forward_cnn(hidden, mask)
        frame_mask = mask.transpose(-1, -2).float().squeeze(-1)
        # token_num from forward_cnn is replaced by the tail-adjusted count.
        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=frame_mask)
        acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_cnn(
        self,
        hidden: torch.Tensor,
        mask: torch.Tensor,
    ):
        """Compute the masked per-frame alphas and their sum over the last dimension."""
        padded = self.pad(hidden.transpose(1, 2))
        activated = torch.relu(self.cif_conv1d(padded)).transpose(1, 2)
        weights = torch.sigmoid(self.cif_output(activated))
        weights = torch.nn.functional.relu(weights * self.smooth_factor - self.noise_threshold)
        weights = weights * mask.transpose(-1, -2).float()
        weights = weights.squeeze(-1)
        return weights, weights.sum(-1)

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Add ``tail_threshold`` weight where ``[1, mask] - [mask, 0]`` is non-zero
        and extend ``hidden`` by one zero frame.

        Returns ``(hidden, alphas, floor(alphas.sum(-1)))``; the incoming
        ``token_num`` is not used.
        """
        b, t, d = hidden.size()
        pad_zero = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        pad_one = torch.ones_like(pad_zero)
        shifted_diff = torch.cat([pad_one, mask], dim=1) - torch.cat([mask, pad_zero], dim=1)
        tail_weight = shifted_diff * self.tail_threshold
        alphas = torch.cat([alphas, pad_zero], dim=1) + tail_weight
        extra_frame = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, extra_frame], dim=1)
        return hidden, alphas, torch.floor(alphas.sum(dim=-1))
@torch.jit.script
def cif_v1_export(hidden, alphas, threshold: float):
    """Vectorised (loop-free) integration of ``hidden`` weighted by ``alphas``.

    A frame "fires" where the floored cumulative sum of ``alphas`` increases.
    The fired frames of every batch row are packed into a zero-padded tensor.

    Args:
        hidden: 3-D tensor unpacked as ``(batch_size, len_time, hidden_size)``.
        alphas: weights, accumulated along dimension 1.
        threshold: converted to a tensor below but not referenced afterwards in
            this function.

    Returns:
        ``(frame_fires, fires)``.
    """
    device = hidden.device
    dtype = hidden.dtype
    batch_size, len_time, hidden_size = hidden.size()
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

    # NOTE: this zero tensor is overwritten further down before it is read.
    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)

    # prefix_sum = torch.cumsum(alphas, dim=1)
    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
        torch.float32
    )  # cumsum precision degradation cause wrong result in extreme
    prefix_sum_floor = torch.floor(prefix_sum)
    # Shift the prefix sum right by one step (first column forced to 0) so that
    # each position can be compared with its predecessor.
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)

    dislocation_prefix_sum_floor[:, 0] = 0
    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor

    # A frame fires where the integer part of the running sum grew.
    fire_idxs = dislocation_diff > 0
    fires[fire_idxs] = 1
    fires = fires + prefix_sum - prefix_sum_floor

    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
    frames = prefix_sum_hidden[fire_idxs]
    shift_frames = torch.roll(frames, 1, dims=0)

    batch_len = fire_idxs.sum(1)
    batch_idxs = torch.cumsum(batch_len, dim=0)
    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
    shift_batch_idxs[0] = 0
    # Zero the shifted entry at the first fired frame of every batch row.
    shift_frames[shift_batch_idxs] = 0

    remains = fires - torch.floor(fires)
    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]

    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
    shift_remain_frames[shift_batch_idxs] = 0

    # Difference of consecutive prefix sums, corrected by the weight carried
    # over a fire boundary.
    frames = frames - shift_frames + shift_remain_frames - remain_frames

    # max_label_len = batch_len.max()
    max_label_len = alphas.sum(dim=-1)
    max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)

    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
    # Scatter the packed frames into the zero-padded per-row layout.
    frame_fires_idxs = indices < batch_len.unsqueeze(1)
    frame_fires[frame_fires_idxs] = frames
    return frame_fires, fires
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
    """Step-by-step integration of ``hidden`` weighted by ``alphas``.

    Walks over dimension 1 of ``alphas``; whenever the running weight reaches
    ``threshold`` the accumulated frame is recorded as fired and the remainder
    of the current weight starts the next frame.

    Args:
        hidden: 3-D tensor unpacked as ``(batch_size, len_time, hidden_size)``.
        alphas: weights indexed as ``alphas[:, t]``.
        threshold: firing threshold.

    Returns:
        ``(frame_fires, fires)``: the fired frames, zero-padded and truncated to
        the largest fire count in the batch, and the running weight recorded at
        every step.
    """
    batch_size, len_time, hidden_size = hidden.size()
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

    # loop vars
    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        # Weight still missing to reach 1; must be computed BEFORE alpha is added.
        distribution_completion = (
            torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
        )

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        # Rebinding (not in-place) keeps the tensor just appended to list_fires intact.
        integrate = torch.where(
            fire_place,
            integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
            integrate,
        )
        cur = torch.where(fire_place, distribution_completion, alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        # After a fire the next frame starts with the leftover weight of this step.
        frame = torch.where(
            fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
        )

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)

    fire_idxs = fires >= threshold
    frame_fires = torch.zeros_like(hidden)
    max_label_len = frames[0, fire_idxs[0]].size(0)
    for b in range(batch_size):
        frame_fire = frames[b, fire_idxs[b]]
        frame_len = frame_fire.size(0)
        frame_fires[b, :frame_len, :] = frame_fire

        # Track the largest number of fired frames over the batch.
        if frame_len >= max_label_len:
            max_label_len = frame_len
    frame_fires = frame_fires[:, :max_label_len, :]
    return frame_fires, fires
class mae_loss(torch.nn.Module):
    """Summed L1 distance between two length tensors, divided by a normaliser.

    The normaliser is the size of dimension 0 of ``token_length``, or the sum of
    ``token_length`` (as float32) when ``normalize_length`` is true.
    """

    def __init__(self, normalize_length=False):
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction="sum")

    def forward(self, token_length, pre_token_length):
        """Return the normalised L1 loss between ``token_length`` and ``pre_token_length``."""
        if self.normalize_length:
            denominator = token_length.sum().type(torch.float32)
        else:
            denominator = token_length.size(0)
        total = self.criterion(token_length, pre_token_length)
        return total / denominator
def cif(hidden, alphas, threshold):
    """Step-by-step integration of ``hidden`` weighted by ``alphas``.

    Walks over dimension 1 of ``alphas``; whenever the running weight reaches
    ``threshold`` the accumulated frame is recorded as fired and the remainder
    of the current weight starts the next frame.

    Args:
        hidden: 3-D tensor unpacked as ``(batch_size, len_time, hidden_size)``.
        alphas: weights indexed as ``alphas[:, t]``.
        threshold: firing threshold.

    Returns:
        ``(frames, fires)``: the fired frames of every batch row, zero-padded to
        the largest rounded alpha sum in the batch, and the running weight
        recorded at every step.
    """
    batch_size, len_time, hidden_size = hidden.size()

    # Loop-invariant constant, never modified in place.
    ones = torch.ones([batch_size], device=hidden.device)

    # loop vars
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        # Weight still missing to reach 1; must be computed BEFORE alpha is added.
        distribution_completion = ones - integrate

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        # Rebinding (not in-place) keeps the tensor just appended to list_fires intact.
        integrate = torch.where(fire_place, integrate - ones, integrate)
        cur = torch.where(fire_place, distribution_completion, alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        # After a fire the next frame starts with the leftover weight of this step.
        frame = torch.where(
            fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
        )

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    list_ls = []
    len_labels = torch.round(alphas.sum(-1)).int()
    max_label_len = len_labels.max()
    for b in range(batch_size):
        fire = fires[b, :]
        # squeeze(-1) keeps the index 1-D even when exactly one frame fired
        # (a bare squeeze() would collapse it to a 0-dim tensor).
        fired_at = torch.nonzero(fire >= threshold).squeeze(-1)
        l = torch.index_select(frames[b, :, :], 0, fired_at)
        pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
        list_ls.append(torch.cat([l, pad_l], 0))
    return torch.stack(list_ls, 0), fires
def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
    """Vectorized integrate-and-fire on the weights only (no hidden states).

    A step fires when the floor of the running sum of ``alphas`` increases,
    i.e. when the sum crosses an integer.

    Args:
        alphas: 2-D tensor unpacked as (batch_size, len_time).
        threshold: kept for interface compatibility; the computation below does
            not read it (firing is decided by integer crossings).
        return_fire_idxs: when True, also return the boolean firing mask.

    Returns:
        ``fires`` with shape (batch_size, len_time): 1 at firing steps plus the
        fractional part of the running sum at every step. With
        ``return_fire_idxs=True`` the tuple ``(fires, fire_idxs)``.
    """
    batch_size, len_time = alphas.size()
    device = alphas.device
    dtype = alphas.dtype
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)

    # Only query the CUDA runtime when the data actually lives on a CUDA device:
    # torch.cuda.get_device_name() raises on CPU-only hosts, and without an
    # argument it reports the current device rather than the one holding alphas.
    on_bi_v150 = alphas.is_cuda and torch.cuda.get_device_name(device) == "Iluvatar BI-V150"
    if on_bi_v150:
        # This device takes the plain cumsum path (no float64 accumulation).
        prefix_sum = torch.cumsum(alphas, dim=1)
    else:
        # cumsum precision degradation causes wrong results in extreme cases,
        # so accumulate in float64 and cast back.
        prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32)

    prefix_sum_floor = torch.floor(prefix_sum)
    # Running sum shifted right by one step; column 0 wraps around from the
    # last step, so its floor is reset to 0.
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
    dislocation_prefix_sum_floor[:, 0] = 0

    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
    fire_idxs = dislocation_diff > 0
    fires[fire_idxs] = 1
    fires = fires + prefix_sum - prefix_sum_floor
    if return_fire_idxs:
        return fires, fire_idxs
    return fires
# Vectorized CIF: derives the fired frames from cumulative sums instead of a per-step loop.
# Returns (frame_fires, fires): frame_fires is a zero-initialised
# (batch_size, max_label_len, hidden_size) tensor filled with the fired frames of each
# batch item; fires is the first output of cif_wo_hidden_v1.
def cif_v1(hidden, alphas, threshold):
# Per-step fire values and the boolean mask of the steps that fire.
fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
# frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
# NOTE(review): this zero tensor is overwritten two statements below without being read.
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
# Running sum over time of the alpha-weighted hidden states.
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
# Boolean-mask selection: one row per firing step, concatenated across the batch.
frames = prefix_sum_hidden[fire_idxs]
# Previous fired row for every row (rolled by one along the flattened fire axis).
shift_frames = torch.roll(frames, 1, dims=0)
# Number of fires per batch item and the offset where each item's rows start.
batch_len = fire_idxs.sum(1)
batch_idxs = torch.cumsum(batch_len, dim=0)
shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
shift_batch_idxs[0] = 0
# The first fire of each item has no predecessor, so its "previous" row is zeroed.
# NOTE(review): presumably every batch item fires at least once; an item with zero
# fires at the end of the batch would yield an offset equal to the row count - verify.
shift_frames[shift_batch_idxs] = 0
# Fractional part of the fire values.
remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
# Frame = difference of consecutive prefix sums, plus the remainder carried in from the
# previous fire, minus the remainder carried out of this one.
frames = frames - shift_frames + shift_remain_frames - remain_frames
# max_label_len = batch_len.max()
max_label_len = (
torch.round(alphas.sum(-1)).int().max()
) # torch.round to calculate the max length
# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
# Mask of the first batch_len[b] slots of every row; the flattened frames are written there.
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
return frame_fires, fires
def cif_wo_hidden(alphas, threshold):
    """Integrate the weights step by step and record the running sum.

    Args:
        alphas: 2-D tensor unpacked as (batch_size, len_time).
        threshold: once the running sum reaches it, ``threshold`` is subtracted
            before the next step.

    Returns:
        Tensor (batch_size, len_time) holding the running sum at every step,
        captured before the subtraction is applied.
    """
    batch_size, _ = alphas.size()
    device = alphas.device

    running = torch.zeros([batch_size], device=device)
    history = []
    for weight in alphas.unbind(1):
        # In-place add, then snapshot; torch.where rebinds the name to a new
        # tensor so the snapshot is left untouched by later steps.
        running += weight
        history.append(running)
        crossed = running >= threshold
        running = torch.where(
            crossed,
            running - torch.ones([batch_size], device=device) * threshold,
            running,
        )
    return torch.stack(history, 1)

View File

@@ -0,0 +1,746 @@
import logging
import os
import random
import re
import string
import time
import traceback
from typing import Union
import torch
import torch.nn as nn
from funasr.metrics.compute_acc import compute_accuracy
from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable, to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
from transformers import AutoConfig, AutoModelForCausalLM
from funasr.models.fun_asr_nano.ctc import CTC
from funasr.models.fun_asr_nano.tools.utils import forced_align
# Maps the precision names used in configs/kwargs ("bf16", "fp16", "fp32") to torch dtypes.
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
# FunASRNano wires an audio encoder, an audio adaptor and a causal LLM together, with an
# optional CTC decoder branch. It is registered in the funasr `tables` registry under
# "model_classes" / "FunASRNano".
@tables.register("model_classes", "FunASRNano")
class FunASRNano(nn.Module):
# Builds the sub-modules from their config dicts.
#   audio_encoder / audio_encoder_conf: encoder name (a hub model id when conf["hub"] == "ms",
#       otherwise a key of tables.encoder_classes) and its config.
#   audio_adaptor / audio_adaptor_conf: key of tables.adaptor_classes and its config.
#   llm: not read in this constructor.
#   llm_conf: "init_param_path" is passed to AutoConfig.from_pretrained; optional
#       "load_kwargs", "freeze", "activation_checkpoint", "llm_dtype".
#   input_size: forwarded to the encoder constructor in the non-hub branch.
#   length_normalized_loss: stored on the instance.
#   kwargs: optional CTC settings ("ctc_decoder", "ctc_tokenizer", "ctc_tokenizer_conf",
#       "dataset_conf", "ctc_vocab_size", "ctc_decoder_conf", "ctc_conf", "ctc_weight",
#       "detach_ctc_decoder").
# NOTE(review): the *_conf parameters default to None but .get() is called on them
# unconditionally, so callers apparently always pass dicts - confirm.
def __init__(
self,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
length_normalized_loss: bool = False,
**kwargs,
):
super().__init__()
# audio encoder
hub = audio_encoder_conf.get("hub", None)
self.audio_encoder_activation_checkpoint = audio_encoder_conf.get(
"activation_checkpoint", False
)
# hub == "ms": build a full funasr AutoModel and reuse only its encoder; the output size
# is -1 when the loaded model does not expose `encoder_output_size`.
if hub == "ms":
from funasr import AutoModel
model = AutoModel(model=audio_encoder, model_revision="master")
audio_encoder_output_size = (
model.model.encoder_output_size
if hasattr(model.model, "encoder_output_size")
else -1
)
audio_encoder = (
model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
)
else:
encoder_class = tables.encoder_classes.get(audio_encoder)
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
# The encoder is frozen (no gradients, eval mode) unless conf["freeze"] is falsy.
freeze = audio_encoder_conf.get("freeze", True)
if freeze:
for _, param in audio_encoder.named_parameters():
param.requires_grad = False
audio_encoder.eval()
self.audio_encoder = audio_encoder
# llm
self.llm = None
init_param_path = llm_conf.get("init_param_path", None)
llm_dim = None
llm_load_kwargs = llm_conf.get("load_kwargs", {})
# The LLM is instantiated from its config via from_config; this constructor does not
# call a weight-loading API for it (presumably weights are loaded by the caller - verify).
config = AutoConfig.from_pretrained(init_param_path)
model = AutoModelForCausalLM.from_config(config, **llm_load_kwargs)
freeze = llm_conf.get("freeze", True)
if freeze:
for _, param in model.named_parameters():
param.requires_grad = False
model.eval()
if llm_conf.get("activation_checkpoint", False):
model.gradient_checkpointing_enable()
# Cast the LLM to the configured precision and read its embedding width.
self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
self.llm = model.to(dtype_map[self.llm_dtype])
llm_dim = model.get_input_embeddings().weight.shape[-1]
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
# The adaptor config is updated in place with the encoder and LLM widths found above.
if audio_encoder_output_size > 0:
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
audio_adaptor_conf["llm_dim"] = (
llm_dim if llm_dim is not None else audio_adaptor_conf["llm_dim"]
)
audio_adaptor = adaptor_class(**audio_adaptor_conf)
freeze = audio_adaptor_conf.get("freeze", False)
if freeze:
for _, param in audio_adaptor.named_parameters():
param.requires_grad = False
audio_adaptor.eval()
self.audio_adaptor = audio_adaptor
self.use_low_frame_rate = audio_adaptor_conf.get("use_low_frame_rate", False)
# ctc decoder
self.ctc_decoder = None
# TODO: fix table name
# The CTC decoder class is looked up in the adaptor table (see the TODO above).
ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
if ctc_decoder_class is not None:
# Tokenizer name/config come from kwargs when present there, else from kwargs["dataset_conf"].
ctc_tokenizer = (
kwargs.get("ctc_tokenizer", None)
if "ctc_tokenizer" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer"]
)
ctc_tokenizer_conf = (
kwargs.get("ctc_tokenizer_conf", None)
if "ctc_tokenizer_conf" in kwargs
else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
)
if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
self.ctc_tokenizer = ctc_tokenizer
assert ctc_tokenizer is not None, f"ctc_tokenizer must be set"
ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
if audio_encoder_output_size > 0:
ctc_decoder_conf["encoder_dim"] = audio_encoder_output_size
self.ctc_decoder = ctc_decoder_class(**ctc_decoder_conf)
# Optional checkpoint for the CTC decoder, loaded non-strictly on CPU.
init_param_path = ctc_decoder_conf.get("init_param_path", None)
if init_param_path is not None:
src_state = torch.load(init_param_path, map_location="cpu")
flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
freeze = ctc_decoder_conf.get("freeze", False)
if freeze:
for _, param in self.ctc_decoder.named_parameters():
param.requires_grad = False
self.ctc_decoder.eval()
# The blank id defaults to the last index of the CTC vocabulary.
ctc_conf = kwargs.get("ctc_conf", {})
self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
self.ctc_weight = kwargs.get("ctc_weight", 0.3)
self.ctc = CTC(
odim=ctc_vocab_size,
encoder_output_size=audio_encoder_output_size,
blank_id=self.blank_id,
**ctc_conf,
)
self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
self.error_calculator = None
self.length_normalized_loss = length_normalized_loss
# RANK is read from the environment only for the log line below.
rank = int(os.environ.get("RANK", 0))
logging.info(f"rank: {rank}, model is builded.")
# Training forward pass: embeds input_ids with the LLM embedding table, overwrites the
# speech placeholder positions with adaptor outputs, runs the LLM with labels and returns
# (loss, stats, weight) as produced by force_gatherable.
#   speech / speech_lengths: encoder inputs; speech is unpacked as (batch, frames, feature).
#   input_ids, attention_mask, labels_ids: LLM inputs; unpacked/used as 2-D tensors.
#   fbank_beg: start index of the speech span per (batch item, turn).
#   fbank_mask: not read in this method.
#   kwargs["fake_token_len"]: length of the speech span per (batch item, turn).
# Side effects visible here: input_ids, fake_token_len, fbank_beg, labels_ids and
# attention_mask are modified in place (negative / -1 sentinel values are rewritten).
def forward(
self,
speech: torch.Tensor = None,
speech_lengths: torch.Tensor = None,
input_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
labels_ids: torch.Tensor = None,
fbank_beg: torch.Tensor = None,
fbank_mask: torch.Tensor = None,
**kwargs,
):
batch_size, token_num = input_ids.shape
stats = {}
# Negative ids (presumably padding - confirm) are mapped to 0 before the embedding lookup.
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
if speech is not None:
# When speech_lengths has more than one dimension only its first column is used.
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size_speech, frames, _ = speech.shape
# audio encoder
if self.audio_encoder_activation_checkpoint:
from torch.utils.checkpoint import checkpoint
encoder_out, encoder_out_lens = checkpoint(
self.encode, speech, speech_lengths, use_reentrant=False
)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len")
fake_token_len[fake_token_len < 0] = 0
fbank_beg[fbank_beg < 0] = 0
# speech_idx walks the speech batch in the order the (batch item, turn) pairs are
# visited; a start index of 0 means "no speech in this turn".
speech_idx = 0
for batch_idx in range(batch_size):
for turn_id in range(fbank_beg.shape[1]):
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
if fbank_beg_idx > 0:
speech_token_len = fake_token_len[batch_idx, turn_id]
speech_token = encoder_out[speech_idx, :speech_token_len, :]
try:
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
except Exception as e:
# The assignment failed (e.g. a length mismatch): log the shapes and retry
# with the length reported by the adaptor.
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
)
speech_token_len = encoder_out_lens[speech_idx].item()
speech_token = encoder_out[speech_idx, :speech_token_len, :]
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
speech_idx += 1
# Frame statistics for the speech batch.
stats["batch_size_speech"] = batch_size_speech
stats["batch_size_x_frames"] = frames * batch_size_speech
stats["batch_size_real_frames"] = speech_lengths.sum().item()
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
# Autocast is enabled only when the configured LLM dtype is not fp32; device types other
# than cuda/xpu/mps fall back to "cpu".
device_type = next(self.parameters()).device.type
with torch.autocast(
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
enabled=True if self.llm_dtype != "fp32" else False,
dtype=dtype_map[self.llm_dtype],
):
# -1 labels are rewritten to -100, the ignore label also passed to compute_accuracy below.
labels_ids[labels_ids == -1] = -100
attention_mask[attention_mask < 0] = 0
model_outputs = self.llm(
inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]),
attention_mask=attention_mask,
labels=labels_ids,
)
loss = model_outputs.loss
# Accuracy compares predictions at position t with labels at position t + 1.
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
stats["batch_size"] = batch_size
stats["batch_size_x_tokens"] = token_num * batch_size
stats["batch_size_real_tokens"] = attention_mask.sum().item()
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
# Turns with a positive speech start index, per batch item.
dialog_turns = (fbank_beg > 0).sum(-1)
dialog_turns_max = torch.max(dialog_turns).int().item()
dialog_turns_avg = dialog_turns.sum().item() / batch_size
stats["dialog_turns_max"] = dialog_turns_max
stats["dialog_turns_avg"] = dialog_turns_avg
# force_gatherable: to-device and to-tensor if scalar for DataParallel
# NOTE(review): `labels_ids > 0 + 1` parses as `labels_ids > 1` (arithmetic binds tighter
# than comparison), so this counts labels greater than 1 - confirm this is intended.
if self.length_normalized_loss:
batch_size = int((labels_ids > 0 + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def forward_export(self, speech, speech_lengths, **kwargs):
    """Run the audio encoder followed by the audio adaptor.

    Extra keyword arguments are accepted and ignored.

    Returns:
        The ``(output, output_lengths)`` pair produced by the adaptor.
    """
    encoded, encoded_lens = self.audio_encoder(speech, speech_lengths)
    adapted = self.audio_adaptor(encoded, encoded_lens)
    encoder_out, encoder_out_lens = adapted
    return encoder_out, encoder_out_lens
def encode(self, speech, speech_lengths):
    """Apply the audio encoder and return its ``(output, output_lengths)`` pair."""
    outputs = self.audio_encoder(speech, speech_lengths)
    # Unpack explicitly so anything other than a 2-item result fails here.
    features, feature_lens = outputs
    return features, feature_lens
def data_template(self, data):
    """Group chat messages by role.

    Args:
        data: iterable of dicts with "role" and "content" keys; a user message
            may also carry an "audio" entry.

    Returns:
        Dict with "system", "user" and "assistant" lists. A user message that
        carries audio is stored as ``[content, audio]``. The system list is
        repeated once per user message. Messages with any other role are
        dropped.
    """
    system_msgs, user_msgs, assistant_msgs = [], [], []
    for message in data:
        role, content = message["role"], message["content"]
        if role == "user":
            user_msgs.append([content, message["audio"]] if "audio" in message else content)
        elif role == "system":
            system_msgs.append(content)
        elif role == "assistant":
            assistant_msgs.append(content)
    return {
        "system": system_msgs * len(user_msgs),
        "user": user_msgs,
        "assistant": assistant_msgs,
    }
def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data=None, **kwargs):
    """Turn grouped chat contents into token ids, labels and speech features.

    Each (system, user, assistant) turn is rendered with the ``<|im_start|>`` /
    ``<|im_end|>`` chat markup. A ``<|startofspeech|>...<|endofspeech|>`` span
    in the user prompt is replaced by placeholder token ids (0) whose count
    matches the number of speech embeddings expected for that audio.

    Args:
        contents: dict with "system", "user", "assistant" lists (see
            ``data_template``); a user entry may be ``[prompt, audio]``.
        tokenizer: object with ``encode(str) -> list[int]``.
        frontend: feature frontend; ``fs``, ``frame_shift`` and ``lfr_n`` are
            read when audio is present.
        meta_data: optional dict that receives timing entries ("load_data",
            "extract_feat", "batch_data_time"). A new dict is created per call
            when omitted.
        **kwargs: "dataset_conf" (``do_think``, ``sys_prompt``),
            "multiturn_num_max", "max_token_length",
            "infer_with_assistant_input", "prev_text", "data_type"; also
            forwarded to ``load_audio_text_image_video``.

    Returns:
        Dict with "speech", "speech_lengths", "fbank_mask", "fbank_beg",
        "fake_token_len", "input_ids", "attention_mask", "labels_ids",
        "source_ids" and "target_ids". "speech" and "speech_lengths" are empty
        lists when no audio was found.

    Raises:
        Whatever ``load_audio_text_image_video`` raises: the failure is logged
        and re-raised.
    """
    # A `{}` default would be one dict shared (and mutated) by every call.
    if meta_data is None:
        meta_data = {}

    system = contents["system"]
    user = contents["user"]
    assistant = contents["assistant"]
    pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
    do_think = True
    sys_prompt = True
    if "dataset_conf" in kwargs:
        do_think = kwargs["dataset_conf"].get("do_think", True)
        sys_prompt = kwargs["dataset_conf"].get("sys_prompt", True)

    input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
        [],
        [],
        [],
        [],
        [],
        [],
        [],
    )
    input_source_ids = []
    # Defined up front so that zero turns yields empty tensors instead of a NameError.
    target_ids = []
    for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
        if i >= kwargs.get("multiturn_num_max", 5):
            break
        if len(input_ids) > kwargs.get("max_token_length", 1500):
            break
        if isinstance(user_prompt, (list, tuple)):
            user_prompt, audio = user_prompt

        # Only the first turn carries the system prompt (unless sys_prompt is off).
        # With infer_with_assistant_input the user turn is left open.
        if i == 0:
            if kwargs.get("infer_with_assistant_input", False):
                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}"
                if not sys_prompt:
                    source_input = f"<|im_start|>user\n{user_prompt}"
            else:
                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                if not sys_prompt:
                    source_input = (
                        f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                    )
        else:
            if kwargs.get("infer_with_assistant_input", False):
                source_input = f"<|im_start|>user\n{user_prompt}"
            else:
                source_input = (
                    f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                )
        if not do_think:
            source_input += "<think>\n\n</think>\n\n"
        if kwargs.get("prev_text", None) is not None:
            source_input += kwargs["prev_text"]

        # The capturing group keeps the speech spans as their own list items.
        splits = pattern.split(source_input)
        source_ids = []
        fbank_mask_i = []
        fake_token_len_i = 0
        fbank_beg_i = -1
        speech, speech_lengths = [], []
        for sub_str in splits:
            if not sub_str.startswith("<|startofspeech|>"):
                sub_token = tokenizer.encode(sub_str)
                source_ids += sub_token
                fbank_mask_i += [0] * len(sub_token)
            else:
                sub_str = sub_str.replace("<|startofspeech|>", "").replace(
                    "<|endofspeech|>", ""
                )
                if sub_str.startswith("!"):
                    sub_str = sub_str[1:]
                    if sub_str.startswith("!"):  # !!: audio sample point
                        sub_str = audio
                    try:
                        time1 = time.perf_counter()
                        data_src = load_audio_text_image_video(
                            sub_str, fs=frontend.fs, **kwargs
                        )
                        time2 = time.perf_counter()
                        meta_data["load_data"] = f"{time2 - time1:0.3f}"
                    except Exception as e:
                        logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
                        # Continuing would hit an unbound `data_src`/`time2`, or
                        # silently reuse the previous turn's audio.
                        raise
                    speech, speech_lengths = extract_fbank(
                        data_src,
                        data_type=kwargs.get("data_type", "sound"),
                        frontend=frontend,
                        is_final=True,
                    )  # speech: [b, T, d]
                    time3 = time.perf_counter()
                    meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
                    meta_data["batch_data_time"] = (
                        speech_lengths.sum().item()
                        * frontend.frame_shift
                        * frontend.lfr_n
                        / 1000
                    )
                    # Number of placeholder tokens for this audio.
                    # NOTE(review): the low-frame-rate arithmetic presumably mirrors
                    # the encoder/adaptor downsampling - confirm against those modules.
                    if self.use_low_frame_rate:
                        olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                        olens = 1 + (olens - 3 + 2 * 1) // 2
                        fake_token_len_i = (olens - 1) // 2 + 1
                    else:
                        fake_token_len_i = speech_lengths[0].item()
                    fake_token = [0] * fake_token_len_i
                    fbank_beg_i = len(source_ids)
                    source_ids += fake_token
                    fbank_mask_i += [1] * len(fake_token)

        # The speech start is made absolute by adding the tokens of earlier turns;
        # a turn without speech contributes -1 + len(input_ids).
        fbank_beg += [fbank_beg_i + len(input_ids)]
        fake_token_len += [fake_token_len_i]
        # Prompt positions are masked out of the loss with -100.
        source_mask = [-100] * len(source_ids)
        target_out = f"{target_out}<|im_end|>"
        target_ids = tokenizer.encode(target_out)
        input_source_ids = input_ids + source_ids
        input_ids += source_ids + target_ids
        labels += source_mask + target_ids
        fbank_mask += fbank_mask_i
        if len(speech) > 0:
            fbank.append(speech[0, :, :])
            fbank_lens.append(speech_lengths)

    input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
    attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
    labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
    fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
    fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
    fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
    source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
    target_ids = torch.tensor(target_ids, dtype=torch.int64)
    if len(fbank) > 0:
        speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
        speech_lengths = torch.nn.utils.rnn.pad_sequence(
            fbank_lens, batch_first=True, padding_value=-1
        )
    else:
        speech = []
        speech_lengths = []
    # Every entry except "speech", "speech_lengths" and "labels_ids" gets a
    # leading axis of size 1.
    output = {
        "speech": speech,
        "speech_lengths": speech_lengths,
        "fbank_mask": fbank_mask[None, :],
        "fbank_beg": fbank_beg[None,],
        "fake_token_len": fake_token_len[None, :],
        "input_ids": input_ids[None,],
        "attention_mask": attention_mask[None,],
        "labels_ids": labels,
        "source_ids": source_ids[None, :],
        "target_ids": target_ids[None, :],
    }
    return output
# Prepares LLM input embeddings for inference on a single sample.
# Returns (inputs_embeds, contents, batch, source_ids, meta_data):
#   inputs_embeds: token embeddings with the speech placeholder span overwritten by adaptor output.
#   contents: the role-grouped messages from data_template.
#   batch: the data_load_speech output after to_device.
#   source_ids: prompt-side token ids taken from batch["source_ids"].
#   meta_data: timing entries plus encoder / adaptor outputs when speech is present.
# data_lengths and key are accepted but not read here. kwargs must contain "device".
# Raises NotImplementedError when kwargs["batch_size"] > 1.
def inference_prepare(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
meta_data = {}
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# Only the first sample of data_in is processed.
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
# audio encoder
speech = batch["speech"]
if len(speech) > 0:
# Precomputed embeddings passed through kwargs are used in place of running the encoder.
if "audio_embedding" in kwargs and "audio_embedding_lens" in kwargs:
encoder_out = kwargs["audio_embedding"]
encoder_out_lens = kwargs["audio_embedding_lens"]
else:
speech_lengths = batch["speech_lengths"][:, 0]
# fp16
if kwargs.get("fp16", False):
speech = speech.to(torch.float16)
elif kwargs.get("bf16", False):
speech = speech.to(torch.bfloat16)
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
# Both encoder and adaptor outputs are exposed through meta_data; inference_llm reads
# "encoder_out" / "encoder_out_lens" from it for the CTC branch.
meta_data["encoder_out"] = encoder_out
meta_data["encoder_out_lens"] = encoder_out_lens
meta_data["audio_adaptor_out"] = adaptor_out
meta_data["audio_adaptor_out_lens"] = adaptor_out_lens
input_ids = batch["input_ids"]
source_ids = batch["source_ids"]
fbank_beg = batch["fbank_beg"]
fake_token_len = batch["fake_token_len"]
# Without teacher forcing only the prompt-side ids are embedded.
if not kwargs.get("teacherforcing", False):
input_ids = source_ids
# Negative ids / lengths / start indices are clamped to 0 in place.
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len[fake_token_len < 0] = 0
fbank_beg[fbank_beg < 0] = 0
# Same splice as in forward(): a start index of 0 means "no speech in this turn".
speech_idx = 0
for batch_idx in range(batch_size):
for turn_id in range(fbank_beg.shape[1]):
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
if fbank_beg_idx > 0:
speech_token_len = fake_token_len[batch_idx, turn_id]
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
try:
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
except Exception as e:
# The assignment failed: log the shapes and retry with the adaptor-reported length.
# NOTE(review): speech_lengths is only assigned when the encoder runs in this call,
# so the log line below presumably cannot be built on the precomputed-embedding
# path - verify.
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
)
speech_token_len = adaptor_out_lens[speech_idx].item()
speech_token = adaptor_out[speech_idx, :speech_token_len, :]
inputs_embeds[
batch_idx,
fbank_beg_idx : fbank_beg_idx + speech_token_len,
:,
] = speech_token
speech_idx += 1
return inputs_embeds, contents, batch, source_ids, meta_data
def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
    """Compose the transcription instruction placed in the user turn.

    Args:
        hotwords: words to list as context; an empty list omits the context
            preamble entirely.
        language: target language name appended to the instruction, or None.
        itn: when False, a clause asking for no text normalization is appended.

    Returns:
        The prompt string.
    """
    pieces = []
    if len(hotwords) > 0:
        joined = ", ".join(hotwords)
        pieces.append(
            "请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
        )
        pieces.append(f"热词列表:[{joined}]\n")
    pieces.append("语音转写" if language is None else f"语音转写成{language}")
    if not itn:
        pieces.append(",不进行文本规整")
    # NOTE(review): the original appended an empty string at the end, a no-op;
    # confirm whether a suffix was intended there.
    return "".join(pieces)
def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
    """Wrap one audio input into a three-message chat (system, user, assistant).

    A string is embedded in the user content after a single ``!`` marker. A
    tensor is referenced with the ``!!`` marker and attached under the "audio"
    key. The assistant content is the literal string "null".

    Returns:
        The message list, or None when ``data`` is neither a str nor a tensor.
    """
    if isinstance(data, str):
        user_turn = {
            "role": "user",
            "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>",
        }
    elif isinstance(data, torch.Tensor):
        user_turn = {
            "role": "user",
            "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
            "audio": data,
        }
    else:
        return None
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        user_turn,
        {"role": "assistant", "content": "null"},
    ]
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Build the prompt and chat messages for each input, then run ``inference_llm``.

    Args:
        data_in: iterable of audio inputs accepted by ``generate_chatml``.
        data_lengths, tokenizer, frontend: forwarded unchanged.
        key: one identifier per input; random "rand_key_" + 13 alphanumeric
            characters are generated when None.
        **kwargs: "hotwords", "language" and "itn" shape the prompt; all kwargs
            are forwarded to ``inference_llm``.

    Returns:
        Whatever ``inference_llm`` returns.
    """
    hotwords = kwargs.get("hotwords", [])
    language = kwargs.get("language", None)
    itn = kwargs.get("itn", True)
    prompt = self.get_prompt(hotwords, language, itn)

    data_in = [self.generate_chatml(prompt, data) for data in data_in]

    if key is None:
        alphabet = string.ascii_letters + string.digits
        key = [
            "rand_key_" + "".join(random.choice(alphabet) for _ in range(13))
            for _ in data_in
        ]

    return self.inference_llm(
        data_in,
        data_lengths=data_lengths,
        key=key,
        tokenizer=tokenizer,
        frontend=frontend,
        **kwargs,
    )
# Runs the LLM on one prepared sample and returns (results, meta_data).
# results is a one-element list whose dict holds "key", "text", "text_tn", "label", "loss"
# (teacher forcing only) and, when a CTC decoder exists, "ctc_text", "ctc_timestamps"
# and "timestamps".
# kwargs read here: "llm_dtype", "fp16", "bf16", "device", "llm_kwargs", "teacherforcing",
# "max_length", "skip_special_tokens", "prev_text", "output_dir".
# NOTE(review): key is indexed (key[0]) without a None check; inference() always supplies
# it, direct callers must too.
def inference_llm(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
# Optional CTC branch: greedy decoding on the encoder output.
# NOTE(review): meta_data["encoder_out"] is only set by inference_prepare when speech is
# present - presumably always the case when a CTC decoder is configured; verify.
ctc_results = []
if self.ctc_decoder is not None:
encoder_out = meta_data["encoder_out"]
encoder_out_lens = meta_data["encoder_out_lens"]
decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
ctc_logits = self.ctc.log_softmax(decoder_out)
b, n, d = encoder_out.size()
# A nested key list is flattened to its first element; a key list shorter than the
# batch is repeated b times.
if isinstance(key[0], (list, tuple)):
key = key[0]
if len(key) < b:
key = key * b
for i in range(b):
# Greedy CTC: argmax per frame, collapse consecutive repeats, drop blanks.
x = ctc_logits[i, : encoder_out_lens[i].item(), :]
yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1)
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()
# Change integer-ids to tokens
text = self.ctc_tokenizer.decode(token_int)
ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
# Precision: kwargs["llm_dtype"] wins; when it is "fp32" the fp16 / bf16 flags may
# override it, and bf16 takes precedence because it is evaluated last.
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
device_type = torch.device(kwargs.get("device", "cuda")).type
with torch.autocast(
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
enabled=True if llm_dtype != "fp32" else False,
dtype=dtype_map[llm_dtype],
):
# The reference label is the last assistant message of the conversation.
label = contents["assistant"][-1]
# The LLM module itself is cast to the chosen dtype (a persistent change on self.llm).
self.llm = self.llm.to(dtype_map[llm_dtype])
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
llm_kwargs = kwargs.get("llm_kwargs", {})
# Free-running generation unless teacher forcing is requested.
if not kwargs.get("teacherforcing", False):
attention_mask = batch.get("attention_mask", None)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=kwargs.get("max_length", 512),
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
**llm_kwargs,
)
response = tokenizer.batch_decode(
generated_ids,
skip_special_tokens=kwargs.get("skip_special_tokens", True),
)[0]
loss = None
else:
# Teacher forcing: run a labelled forward pass and decode the argmax predictions
# that follow the prompt positions.
labels_ids = batch["labels_ids"]
labels_ids[labels_ids == -1] = -100
attention_mask = batch.get("attention_mask", None)
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels_ids,
pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
**llm_kwargs,
)
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
response = tokenizer.batch_decode(
preds,
add_special_tokens=False,
skip_special_tokens=kwargs.get("skip_special_tokens", True),
)[0]
loss = model_outputs.loss.item()
# NOTE(review): kwargs.get returns None when "prev_text" is passed explicitly as None,
# which makes this concatenation raise TypeError; data_load_speech treats None as absent.
response = kwargs.get("prev_text", "") + response
# The writer is created once per instance and reused across calls.
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{0 + 1}best_recog"]
results = []
# "text_tn": the response with every run of characters that are not word characters,
# whitespace, U+3000 or U+4E00-U+9FFF removed.
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
# "text": "/sil" replaced by a space and whitespace runs collapsed to one space.
result_i = {
"key": key[0],
"text": re.sub(r"\s+", " ", response.replace("/sil", " ")),
"text_tn": response_clean,
"label": label,
}
if loss is not None:
result_i["loss"] = loss
results.append(result_i)
# Attach CTC text and forced-alignment timestamps for both the CTC text and the LLM text.
for ctc_result, result in zip(ctc_results, results):
result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
target_ids = torch.tensor(
self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
)
result["ctc_timestamps"] = forced_align(
ctc_result["ctc_logits"], target_ids, self.blank_id
)
target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
for timestamp in timestamps:
timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
# NOTE(review): presumably converts frame indices to seconds at 60 ms per index
# (6 * 10 / 1000) - TODO confirm the frame rate.
timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
if ibest_writer is not None:
ibest_writer["text"][key[0]] = response.replace("\n", " ")
ibest_writer["label"][key[0]] = label.replace("\n", " ")
ibest_writer["text_tn"][key[0]] = response_clean
return results, meta_data
@staticmethod
def from_pretrained(model: str = None, **kwargs):
    """Build a model through ``funasr.AutoModel.build_model`` with remote code trusted.

    Returns:
        The ``(model, kwargs)`` pair returned by ``build_model``.
    """
    # Imported lazily, as in the original, so funasr.AutoModel is only needed
    # when this helper is called.
    from funasr import AutoModel

    built = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
    model, kwargs = built
    return model, kwargs

14
funasr/requirements.txt Normal file
View File

@@ -0,0 +1,14 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PYYaml
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
fastapi
uvicorn
python-multipart

BIN
funasr/warmup.wav Normal file

Binary file not shown.