initial commit

This commit is contained in:
2026-04-08 06:41:00 +00:00
commit 1385a6f46b
23 changed files with 2831 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
# Inference image for the FunASR FastAPI service on the Iluvatar CoreX stack.
FROM corex:4.3.8
WORKDIR /root

# 1) Point apt at the official Ubuntu archive instead of the aliyun mirror
#    (avoids HTTP 403 responses from the mirror).
# 2) Refresh the package index and install the required tools, cleaning the
#    apt lists in the same layer so they never reach the image.
RUN set -eux; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list.d/*.list 2>/dev/null || true; \
    apt-get update; \
    apt-get install -y --no-install-recommends \
        ca-certificates \
        net-tools \
        vim; \
    rm -rf /var/lib/apt/lists/*

# Install Python dependencies BEFORE copying the application source, so that
# editing source files does not invalidate the (slow) dependency layers.
COPY requirements.txt /root/
RUN pip install --no-cache-dir -r /root/requirements.txt \
        -i https://nexus.4pd.io/repository/pypi-all/simple \
        --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple
RUN pip install --no-cache-dir funasr==1.3.1 'transformers>=4.51.3' openai-whisper \
        -i https://nexus.4pd.io/repository/pypi-all/simple \
        --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple

# Application source (COPY instead of ADD: plain local files, no extraction).
COPY . /root/

# Patch files: overwrite the installed FunASR modules with the local versions.
# These must come after the pip install that provides the originals.
COPY ./replaced_files/bi_v150/cif_predictor.py /usr/local/lib/python3.10/site-packages/funasr/models/paraformer/
COPY ./replaced_files/funasr_nano_model.py /usr/local/lib/python3.10/site-packages/funasr/models/fun_asr_nano/model.py

ENTRYPOINT ["python3"]
CMD ["main.py"]

28
funasr/README.md Normal file
View File

@@ -0,0 +1,28 @@
# 天数智芯 天垓150 ASR(FunASR 架构)
## 镜像构造
```shell
docker build -f ./Dockerfile.funasr-bi150 -t <your_image> .
```
其中,基础镜像 corex:4.3.8 可通过联系天数智芯(智铠100 厂商)技术支持获取。
## 使用说明
### 使用 FastAPI 启动ASR服务
例如:
```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
-v /mnt/contest_ceph/leaderboard/modelHubXC/iic/SenseVoiceSmall:/model \
--network=host <your_image> \
main.py --model_dir /model --model_type sensevoice --use_gpu --port 1111
```
具体参数代码设定可参考代码文件
### 测试ASR服务
项目根路径的 `sample_data` 目录下附带了中文测试音频及其配套内容:
```shell
curl -X POST http://localhost:1111/transduce \
-F "audio=@../sample_data/lei-jun-test.wav" \
-F "lang=zh"
```

271
funasr/fastapi_funasr.py Normal file
View File

@@ -0,0 +1,271 @@
import os
import time
import argparse
import torchaudio
import torch
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uuid
import uvicorn
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# from funasr.models.fun_asr_nano.model import FunASRNano
# Staging directory for uploaded audio; created at import time.
os.makedirs("./input", exist_ok=True)
# Service state reported by /health: stays "Running" until load_model()
# completes, at which point it becomes "Success".
status = "Running"
# Both are populated by load_model() during the FastAPI startup event.
model = None
device = ""
app = FastAPI()
# Accelerator selector taken from the environment ("" when unset). Its prefix
# decides which vendor torch plugin is imported below and which device string
# load_model() chooses.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
# NOTE(review): these imports are not referenced by name anywhere in this
# file; presumably importing them registers the backend with torch — confirm.
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu
def make_all_dense(module: torch.nn.Module):
    """Replace every sparse parameter and non-strided buffer of ``module`` with a dense copy.

    The conversion happens in place on ``module`` and all of its submodules.
    Parameters keep their ``requires_grad`` flag. Buffers keep their
    persistence: a buffer that was registered with ``persistent=False`` stays
    out of ``state_dict()`` after conversion.

    Args:
        module: root module whose parameters/buffers are converted.

    Returns:
        None.
    """

    def _owner_and_leaf(qualified_name):
        # Walk "a.b.c" down to the submodule that owns attribute "c".
        *path, leaf = qualified_name.split(".")
        owner = module
        for part in path:
            owner = getattr(owner, part)
        return owner, leaf

    # Snapshot with list(): the loop replaces attributes while iterating.
    for name, param in list(module.named_parameters(recurse=True)):
        if getattr(param, "is_sparse", False):
            with torch.no_grad():
                dense = param.to_dense().contiguous()
            owner, leaf = _owner_and_leaf(name)
            setattr(owner, leaf, torch.nn.Parameter(dense, requires_grad=param.requires_grad))

    # Buffers (e.g. running_mean). Record which ones are persistent *before*
    # re-registering, so non-persistent buffers are not silently promoted into
    # the state_dict.
    persistent_names = set(module.state_dict().keys())
    for name, buf in list(module.named_buffers(recurse=True)):
        # Sparse tensors in PyTorch have a layout other than torch.strided.
        if buf.layout != torch.strided:
            dense = buf.to_dense().contiguous()
            owner, leaf = _owner_and_leaf(name)
            owner.register_buffer(leaf, dense, persistent=name in persistent_names)
def split_audio(waveform, sample_rate, segment_seconds=20):
    """Cut ``waveform`` along its second axis into consecutive fixed-length pieces.

    Args:
        waveform: 2-D array-like indexed as ``waveform[:, start:stop]``.
        sample_rate: samples per second.
        segment_seconds: length of each piece in seconds; the last piece may be shorter.

    Returns:
        List of non-empty slices, in order.
    """
    chunk_len = segment_seconds * sample_rate
    total = waveform.shape[1]
    pieces = (waveform[:, start:start + chunk_len] for start in range(0, total, chunk_len))
    return [piece for piece in pieces if piece.shape[1] > 0]
# def determine_model_type(model_name):
# if "sensevoice" in model_name.lower():
# return "sensevoice"
# elif "whisper" in model_name.lower():
# return "whisper"
# elif "paraformer" in model_name.lower():
# return "paraformer"
# elif "conformer" in model_name.lower():
# return "conformer"
# elif "uniasr" in model_name.lower():
# return "uni_asr"
# else:
# return "unknown"
@app.on_event("startup")
def load_model():
    """Startup hook: pick the target device, load the FunASR model, optionally warm it up.

    Reads ``app.state.config`` (keys ``use_gpu``, ``model_dir``, ``model_type``,
    ``warmup``) and fills the module globals ``model`` and ``device``. On
    completion ``status`` is set to "Success"; if anything raises, ``status``
    is left at its previous value.
    """
    global status, model, device
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    model_type = config.get("model_type", "sensevoice")
    warmup = config.get("warmup", False)
    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" model_type =", model_type, flush=True)
    print(" use_gpu =", use_gpu, flush=True)
    print(" warmup =", warmup, flush=True)
    # Device string: CPU unless use_gpu, then chosen from the CUSTOM_DEVICE prefix.
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"
    # Accelerator-specific special handling
    if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
        # Whisper on the BI-V100 has to avoid unsupported operators: turn off
        # the flash and memory-efficient SDP kernels and use the math one.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    print(f"device: {device}", flush=True)
    # dense_convert: load on CPU, densify sparse tensors, then move to the device.
    dense_convert = False
    if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
        dense_convert = True
    if device.startswith("npu") and model_type == "whisper":
        # On Ascend NPU, loading Whisper hits a device mismatch for its sparse tensors
        dense_convert = True
    print(f"dense_convert: {dense_convert}", flush=True)
    if dense_convert:
        model = AutoModel(
            model=model_dir,
            vad_model=None,
            disable_update=True,
            device="cpu"
        )
        make_all_dense(model.model)
        model.model.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        model.model.to(device)
        # Keep the wrapper's recorded device in sync with where the weights now live.
        model.kwargs["device"] = device
    else:
        # No VAD / punctuation / speaker models: only the raw ASR capability is tested
        model = AutoModel(
            model=model_dir,
            # vad_model="fsmn-vad",
            # vad_kwargs={"max_single_segment_time": 30000},
            vad_model=None,
            device=device,
            disable_update=True
        )
    if device.startswith("npu") or warmup:
        # Ascend NPU: card initialisation/scheduling is more involved than on
        # other cards, so a warm-up inference is always run first.
        print("Start warmup...", flush=True)
        res = model.generate(input="warmup.wav")
        print("warmup complete.", flush=True)
    status = "Success"
def test_funasr(audio_file, lang):
    """Transcribe ``audio_file`` with the globally loaded model and return the text.

    For every model type except ``uni_asr`` the audio is cut into 20-second
    segments, each segment is written to a uniquely named temporary WAV file,
    transcribed, and removed again (also when transcription fails). Timing
    statistics, including the real-time factor, are printed to stdout.

    Args:
        audio_file: path of the audio file to transcribe.
        lang: language code; passed to Whisper and used to strip spaces from
            Paraformer output when it is "zh".

    Returns:
        The concatenated transcription of all segments.

    Raises:
        RuntimeError: if the configured model type is not supported.
    """
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate
    generated_text = ""
    processing_time = 0
    model_type = app.state.config.get("model_type", "sensevoice")
    if model_type == "uni_asr":
        # uni_asr is designed for long audio and has its own VAD. Feeding it
        # 20 s chunks fails when a chunk holds almost no speech (e.g. only
        # music): everything may get trimmed, leaving an empty model input.
        # So the whole file is passed in one go.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the segments one after another.
        segments = split_audio(waveform, sample_rate, segment_seconds=20)
        for i, segment in enumerate(segments):
            # Unique name per call: requests are served concurrently and must
            # not overwrite or delete each other's segment files.
            segment_path = f"temp_seg_{uuid.uuid4().hex}_{i}.wav"
            try:
                torchaudio.save(segment_path, segment, sample_rate)
                ts1 = time.time()
                text = None
                if model_type == "sensevoice":
                    res = model.generate(
                        input=segment_path,
                        cache={},
                        language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
                        use_itn=True,
                        batch_size_s=60,
                        merge_vad=False,
                        # merge_length_s=15,
                    )
                    text = rich_transcription_postprocess(res[0]["text"])
                elif model_type == "whisper":
                    DecodingOptions = {
                        "task": "transcribe",
                        "language": lang,
                        "beam_size": None,
                        "fp16": False,
                        "without_timestamps": False,
                        "prompt": None,
                    }
                    res = model.generate(
                        DecodingOptions=DecodingOptions,
                        input=segment_path,
                        batch_size_s=0,
                    )
                    text = res[0]["text"]
                elif model_type == "paraformer":
                    res = model.generate(
                        input=segment_path,
                        batch_size_s=300
                    )
                    text = res[0]["text"]
                    # Paraformer emits one character at a time; the spaces in
                    # between would hurt the 1-CER score for Chinese.
                    if lang == "zh":
                        text = text.replace(" ", "")
                elif model_type == "conformer":
                    res = model.generate(
                        input=segment_path,
                        batch_size_s=300
                    )
                    text = res[0]["text"]
                else:
                    raise RuntimeError("unknown model type")
                if text is not None:
                    # Some models output "▁" (U+2581, code point 9601) as a word
                    # separator; replace it (and "_") with a space for readability.
                    text = text.replace("_", " ")
                    text = text.replace(chr(9601), " ")
                ts2 = time.time()
                generated_text += text
                processing_time += (ts2 - ts1)
            finally:
                # Always clean up, including when saving or inference raised.
                if os.path.exists(segment_path):
                    os.remove(segment_path)
    # Zero-length audio would otherwise raise ZeroDivisionError.
    rtf = processing_time / duration if duration > 0 else float("nan")
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Report service readiness based on the module-level ``status`` flag.

    Returns ``{"status": "loading model"}`` while ``status`` is "Running",
    ``{"status": "ok"}`` once it is "Success", and ``{"status": "failed"}``
    for any other value.
    """
    if status == "Running":
        return {"status": "loading model"}
    outcome = "ok" if status == "Success" else "failed"
    return {"status": outcome}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    """Transcribe an uploaded audio file.

    The upload is staged under ``./input`` with a random name, transcribed
    with ``test_funasr`` and deleted again before the function returns,
    whether or not transcription succeeded.

    Args:
        audio: uploaded audio file (multipart field ``audio``).
        lang: language code forwarded to ``test_funasr`` (form field ``lang``).
        background_tasks: kept for interface compatibility; cleanup no longer
            depends on it, because background tasks are skipped when the
            request ends in an HTTPException.

    Returns:
        ``{"generated_text": <transcription>}``.

    Raises:
        HTTPException: status 500 with the traceback when any step fails.
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = test_funasr(file_path, lang)
        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
    finally:
        # Inference is synchronous, so the staged file can be removed right here.
        if os.path.exists(file_path):
            os.remove(file_path)
# if __name__ == "__main__":
# uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)

27
funasr/main.py Normal file
View File

@@ -0,0 +1,27 @@
import argparse
import uvicorn
from fastapi_funasr import app
if __name__ == "__main__":
    # Command-line entry point: parse options, hand them to the FastAPI app
    # through app.state.config, then serve it with uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/model", help="model directory")
    parser.add_argument("--model_type", type=str, default="sensevoice", help="model type, e.g. sensevoice")
    # BooleanOptionalAction keeps `--use_gpu` working and adds `--no-use_gpu`.
    # With the former store_true/default=True pair the flag could never be
    # turned off.
    parser.add_argument(
        "--use_gpu",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="run on the accelerator (default); pass --no-use_gpu to force CPU",
    )
    parser.add_argument("--warmup", action="store_true", help="whether do warmup when first initializing model")
    parser.add_argument("--port", type=int, default=8000, help="service port")
    args = parser.parse_args()
    # Expose the parsed options to the startup hook via app.state.
    app.state.config = {
        "model_dir": args.model_dir,
        "model_type": args.model_type,
        "use_gpu": args.use_gpu,
        "warmup": args.warmup
    }
    uvicorn.run("fastapi_funasr:app",
        host="0.0.0.0",
        port=args.port,
        workers=1
    )

View File

@@ -0,0 +1,762 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import logging
import numpy as np
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from torch.cuda.amp import autocast
@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
    """CIF predictor: per-frame weights (alphas) and token-level acoustic embeddings.

    A depthwise 1-D convolution with a residual connection, followed by a
    linear layer and a sigmoid, yields one weight per frame. ``cif`` then
    integrates the weighted frames and emits an embedding whenever the
    accumulated weight reaches ``threshold``.

    Args:
        idim: channel count of the input frames.
        l_order: zero padding added on the left of the time axis.
        r_order: zero padding added on the right; kernel size is
            ``l_order + r_order + 1``.
        threshold: firing threshold handed to ``cif``.
        dropout: dropout probability applied after the residual sum.
        smooth_factor: multiplier applied to the sigmoid output.
        noise_threshold: value subtracted before the final ReLU.
        tail_threshold: when > 0 and no target length is given, extra weight
            appended after the last frame (see ``tail_process_fn``).
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        tail_threshold=0.45,
    ):
        super().__init__()
        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold

    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Predict alphas and integrate ``hidden`` into token-level embeddings.

        Args:
            hidden: 3-D tensor; assumes (batch, time, idim) — TODO confirm.
            target_label: optional labels; entries different from ``ignore_id``
                are counted to obtain the target length.
            mask: optional mask multiplied into the alphas after
                ``transpose(-1, -2)``; may be None.
            ignore_id: label value excluded from the target length.
            mask_chunk_predictor: optional extra multiplier for the alphas.
            target_label_length: optional explicit target length; takes
                precedence over ``target_label``.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)``.
        """
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            memory = self.cif_conv1d(queries)
            output = memory + context
            output = self.dropout(output)
            output = output.transpose(1, 2)
            output = torch.relu(output)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            # mask is optional: squeezing it unconditionally raised
            # AttributeError when no mask was supplied.
            if mask is not None:
                mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                # Rescale so each row of alphas sums to its target length.
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                hidden, alphas, token_num = self.tail_process_fn(
                    hidden, alphas, token_num, mask=mask
                )
            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
        return acoustic_embeds, token_num, alphas, cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Append one extra time step carrying ``tail_threshold`` weight.

        With a mask, the weight is added at the first position after each
        sequence's masked region (``[1, mask] - [mask, 0]``); without a mask it
        is appended after the last frame of every row. ``hidden`` gets a
        matching all-zero frame.

        Args:
            hidden: 3-D tensor unpacked as ``(b, t, d)``.
            alphas: 2-D weights; assumes (b, t) — TODO confirm.
            token_num: unused; recomputed from the extended alphas.
            mask: optional 2-D mask aligned with ``alphas``.

        Returns:
            ``(hidden, alphas, token_num_floor)`` with one more time step, where
            ``token_num_floor`` is the floored row sum of the new alphas.
        """
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            # Repeat per row so concatenation also works for batch sizes > 1
            # (same handling as CifPredictorV2.tail_process_fn).
            alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor

    def gen_frame_alignments(
        self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
    ):
        """Derive integer frame alignments from the cumulative alphas.

        The per-row token count is the rounded alpha sum in training mode and
        the floored sum otherwise.

        Args:
            alphas: 2-D tensor unpacked as ``(batch_size, maximum_length)``.
            encoder_sequence_length: per-row sequence lengths used for clamping
                and for the final padding mask.

        Returns:
            ``(predictor_alignments, predictor_alignments_length)``, both
            detached. NOTE(review): presumably the token index each frame is
            assigned to, plus its row sum — confirm against callers.
        """
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32
        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()
        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
        # index[b, k, :] == k + 1 for every token slot k.
        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(
            index_div_bool_zeros_count, 0, encoder_sequence_length.max()
        )
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask
        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
            1, 1, maximum_length
        )
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        # After the cumsum, ones[..., j] == j + 1 (1-based frame position).
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (
            (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
            .type(int_type)
            .to(encoder_sequence_length.device)
        )
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(
            encoder_sequence_length.dtype
        )
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2")
class CifPredictorV2(torch.nn.Module):
    """Second CIF predictor variant, with a streaming (chunked) forward path.

    Unlike ``CifPredictor`` the convolution is a full (non-grouped) Conv1d
    without residual connection, and integration uses ``cif_v1``.

    Args:
        idim: channel count of the input frames.
        l_order: zero padding added on the left of the time axis.
        r_order: zero padding added on the right; kernel size is
            ``l_order + r_order + 1``.
        threshold: firing threshold.
        dropout: dropout probability (the layer is created but not applied in
            the forward paths of this class).
        smooth_factor: multiplier applied to the sigmoid output.
        noise_threshold: value subtracted before the final ReLU.
        tail_threshold: extra tail weight; tail handling is active when > 0.
        tf2torch_tensor_name_prefix_torch: stored only; unused in this class.
        tf2torch_tensor_name_prefix_tf: stored only; unused in this class.
        tail_mask: whether ``forward`` passes the mask to ``tail_process_fn``.
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        tail_threshold=0.0,
        tf2torch_tensor_name_prefix_torch="predictor",
        tf2torch_tensor_name_prefix_tf="seq2seq/cif",
        tail_mask=True,
    ):
        super().__init__()
        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tail_mask = tail_mask

    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Predict alphas and integrate ``hidden`` into token-level embeddings.

        Args:
            hidden: 3-D tensor; assumes (batch, time, idim) — TODO confirm.
            target_label: optional labels; entries different from ``ignore_id``
                are counted to obtain the target length.
            mask: optional mask multiplied into the alphas; may be None.
            ignore_id: label value excluded from the target length.
            mask_chunk_predictor: optional extra multiplier for the alphas.
            target_label_length: optional explicit target length (its last
                dimension is squeezed); takes precedence over ``target_label``.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)``.
        """
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            output = torch.relu(self.cif_conv1d(queries))
            output = output.transpose(1, 2)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            # mask is optional: squeezing it unconditionally raised
            # AttributeError when no mask was supplied.
            if mask is not None:
                mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length.squeeze(-1)
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                # Rescale so each row of alphas sums to its target length.
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                if self.tail_mask:
                    hidden, alphas, token_num = self.tail_process_fn(
                        hidden, alphas, token_num, mask=mask
                    )
                else:
                    hidden, alphas, token_num = self.tail_process_fn(
                        hidden, alphas, token_num, mask=None
                    )
            acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_chunk(self, hidden, cache=None, **kwargs):
        """Streaming CIF over one chunk, carrying leftover state in ``cache``.

        Args:
            hidden: 3-D tensor unpacked as ``(batch_size, len_time, hidden_size)``.
            cache: optional dict. Keys read when present: ``"chunk_size"``
                (alphas before ``chunk_size[0]`` and, unless final, from
                ``sum(chunk_size[:2])`` onwards are zeroed), ``"cif_alphas"``
                and ``"cif_hidden"`` (leftover state prepended to this chunk).
                On return ``"cif_alphas"``/``"cif_hidden"`` hold the new
                leftover state. Without a cache no state is kept.
            **kwargs: ``is_final`` (default False) appends a tail frame with
                ``tail_threshold`` weight when a cache is given.

        Returns:
            ``(frames, token_length, None, None)``. ``frames`` stacks the
            emitted embeddings zero-padded to the longest row; when no row
            emitted anything, the (possibly extended) ``hidden`` is returned
            in its place.
        """
        is_final = kwargs.get("is_final", False)
        batch_size, len_time, hidden_size = hidden.shape
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        alphas = alphas.squeeze(-1)
        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        if cache is not None and "chunk_size" in cache:
            alphas[:, : cache["chunk_size"][0]] = 0.0
            if not is_final:
                alphas[:, sum(cache["chunk_size"][:2]) :] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
            cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
            hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
            alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
        if cache is not None and is_final:
            tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
            tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
            tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
            hidden = torch.cat((hidden, tail_hidden), dim=1)
            alphas = torch.cat((alphas, tail_alphas), dim=1)
        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = torch.zeros((hidden_size), device=hidden.device)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.threshold:
                    # Not enough weight yet: keep accumulating.
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # Fire: finish the current token with the share of this
                    # frame needed to reach the threshold, then start the next
                    # token with the remainder.
                    frames += (self.threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.threshold
                    frames = integrate * hidden[b][t]
            # Leftover weight and (weight-normalised) partial frame for the next chunk.
            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)
            token_length.append(torch.tensor(len(list_frame), device=alphas.device))
            list_fires.append(list_fire)
            list_frames.append(list_frame)
        # Store the leftover state once; previously this was written twice
        # with identical values and crashed when no cache was passed.
        if cache is not None:
            cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
            cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
            cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
            cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
        max_token_len = max(token_length)
        if max_token_len == 0:
            return hidden, torch.stack(token_length, 0), None, None
        list_ls = []
        for b in range(batch_size):
            pad_frames = torch.zeros(
                (max_token_len - token_length[b], hidden_size), device=alphas.device
            )
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_frames[b] = torch.stack(list_frames[b])
                list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))
        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Append one extra time step carrying ``tail_threshold`` weight.

        With a mask, the weight is added at the first position after each
        sequence's masked region (``[1, mask] - [mask, 0]``); without a mask it
        is appended after the last frame of every row. ``hidden`` gets a
        matching all-zero frame.

        Args:
            hidden: 3-D tensor unpacked as ``(b, t, d)``.
            alphas: 2-D weights; assumes (b, t) — TODO confirm.
            token_num: unused; recomputed from the extended alphas.
            mask: optional 2-D mask aligned with ``alphas``.

        Returns:
            ``(hidden, alphas, token_num_floor)`` with one more time step, where
            ``token_num_floor`` is the floored row sum of the new alphas.
        """
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            if b > 1:
                alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
            else:
                alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor

    def gen_frame_alignments(
        self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
    ):
        """Derive integer frame alignments from the cumulative alphas.

        The per-row token count is the rounded alpha sum in training mode and
        the floored sum otherwise.

        Args:
            alphas: 2-D tensor unpacked as ``(batch_size, maximum_length)``.
            encoder_sequence_length: per-row sequence lengths used for clamping
                and for the final padding mask.

        Returns:
            ``(predictor_alignments, predictor_alignments_length)``, both
            detached. NOTE(review): presumably the token index each frame is
            assigned to, plus its row sum — confirm against callers.
        """
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32
        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()
        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
        # index[b, k, :] == k + 1 for every token slot k.
        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(
            index_div_bool_zeros_count, 0, encoder_sequence_length.max()
        )
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask
        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
            1, 1, maximum_length
        )
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        # After the cumsum, ones[..., j] == j + 1 (1-based frame position).
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (
            (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
            .type(int_type)
            .to(encoder_sequence_length.device)
        )
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(
            encoder_sequence_length.dtype
        )
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2Export")
class CifPredictorV2Export(torch.nn.Module):
    """Export-oriented wrapper that reuses the layers and settings of an existing predictor.

    The wrapped ``model`` must provide ``pad``, ``cif_conv1d``, ``cif_output``,
    ``threshold``, ``smooth_factor``, ``noise_threshold`` and ``tail_threshold``.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        # Share (not copy) the trained layers.
        self.pad = model.pad
        self.cif_conv1d = model.cif_conv1d
        self.cif_output = model.cif_output
        # Scalar hyper-parameters.
        self.threshold = model.threshold
        self.smooth_factor = model.smooth_factor
        self.noise_threshold = model.noise_threshold
        self.tail_threshold = model.tail_threshold

    def forward(
        self,
        hidden: torch.Tensor,
        mask: torch.Tensor,
    ):
        """Return ``(acoustic_embeds, token_num, alphas, cif_peak)`` for ``hidden`` under ``mask``."""
        alphas, token_num = self.forward_cnn(hidden, mask)
        flat_mask = mask.transpose(-1, -2).float().squeeze(-1)
        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=flat_mask)
        acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_cnn(
        self,
        hidden: torch.Tensor,
        mask: torch.Tensor,
    ):
        """Compute the masked per-frame weights and their per-row sum."""
        padded = self.pad(hidden.transpose(1, 2))
        features = torch.relu(self.cif_conv1d(padded)).transpose(1, 2)
        weights = torch.sigmoid(self.cif_output(features))
        weights = torch.nn.functional.relu(weights * self.smooth_factor - self.noise_threshold)
        weights = (weights * mask.transpose(-1, -2).float()).squeeze(-1)
        return weights, weights.sum(-1)

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Extend ``hidden``/``alphas`` by one step that carries the tail weight.

        ``token_num`` is accepted but ignored; the floored row sum of the
        extended alphas is returned as the third element.
        """
        batch, _, dim = hidden.size()
        zero_col = torch.zeros((batch, 1), dtype=torch.float32, device=alphas.device)
        one_col = torch.ones_like(zero_col)
        # [1, mask] - [mask, 0]: marks the first position after each masked run.
        tail_marker = torch.cat([one_col, mask], dim=1) - torch.cat([mask, zero_col], dim=1)
        tail_weight = tail_marker * self.tail_threshold
        alphas = torch.add(torch.cat([alphas, zero_col], dim=1), tail_weight)
        blank_frame = torch.zeros((batch, 1, dim), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, blank_frame], dim=1)
        return hidden, alphas, torch.floor(alphas.sum(dim=-1))
# Loop-free CIF used on the export path (compiled with TorchScript).
# Inputs: hidden is unpacked as (batch_size, len_time, hidden_size); alphas is
# indexed along dim=1 as time. Returns (frame_fires, fires): the fired frames
# packed per batch row and zero-padded, and the per-step firing values.
@torch.jit.script
def cif_v1_export(hidden, alphas, threshold: float):
    device = hidden.device
    dtype = hidden.dtype
    batch_size, len_time, hidden_size = hidden.size()
    # NOTE(review): this tensor is not referenced again below; firing is
    # decided by integer crossings of the prefix sum instead.
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
    # prefix_sum = torch.cumsum(alphas, dim=1)
    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
        torch.float32
    )  # cumsum precision degradation cause wrong result in extreme
    prefix_sum_floor = torch.floor(prefix_sum)
    # Previous step's prefix sum (rolled by one along time, first column forced to 0).
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
    dislocation_prefix_sum_floor[:, 0] = 0
    # A step fires when the floored prefix sum increases there.
    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
    fire_idxs = dislocation_diff > 0
    fires[fire_idxs] = 1
    fires = fires + prefix_sum - prefix_sum_floor
    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
    # Cumulative weighted hidden at every firing step, flattened over the batch.
    frames = prefix_sum_hidden[fire_idxs]
    shift_frames = torch.roll(frames, 1, dims=0)
    # Start offset of each batch row inside the flattened list of fired frames.
    batch_len = fire_idxs.sum(1)
    batch_idxs = torch.cumsum(batch_len, dim=0)
    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
    shift_batch_idxs[0] = 0
    # The first fired frame of each row has no predecessor to subtract.
    shift_frames[shift_batch_idxs] = 0
    remains = fires - torch.floor(fires)
    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
    shift_remain_frames[shift_batch_idxs] = 0
    # Per-token sum: difference of consecutive cumulative values, corrected by
    # the fractional part carried over each firing step.
    frames = frames - shift_frames + shift_remain_frames - remain_frames
    # max_label_len = batch_len.max()
    max_label_len = alphas.sum(dim=-1)
    max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    # Scatter the flattened frames back into a zero-padded (batch, max_label_len, hidden) tensor.
    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
    frame_fires_idxs = indices < batch_len.unsqueeze(1)
    frame_fires[frame_fires_idxs] = frames
    return frame_fires, fires
# Step-by-step CIF compiled with TorchScript.
# hidden is unpacked as (batch_size, len_time, hidden_size); alphas is indexed
# as alphas[:, t]. Returns (frame_fires, fires): fired frames left-packed per
# batch row and truncated to the longest row, plus the accumulated weight
# recorded at every step.
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
    batch_size, len_time, hidden_size = hidden.size()
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
    # loop state: accumulated weight and partially built frame per batch row
    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []
    for t in range(len_time):
        alpha = alphas[:, t]
        # Weight still missing (before this step) to reach 1.
        distribution_completion = (
            torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
        )
        integrate += alpha
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        # Rows that fired keep only the overflow; torch.where returns a new
        # tensor, so the value appended above is not modified afterwards.
        integrate = torch.where(
            fire_place,
            integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
            integrate,
        )
        # Share of this step that goes into the current frame; the rest
        # ("remainds") seeds the next frame.
        cur = torch.where(fire_place, distribution_completion, alpha)
        remainds = alpha - cur
        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(
            fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
        )
    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    fire_idxs = fires >= threshold
    frame_fires = torch.zeros_like(hidden)
    # Track the largest number of fired frames over the batch.
    max_label_len = frames[0, fire_idxs[0]].size(0)
    for b in range(batch_size):
        frame_fire = frames[b, fire_idxs[b]]
        frame_len = frame_fire.size(0)
        frame_fires[b, :frame_len, :] = frame_fire
        if frame_len >= max_label_len:
            max_label_len = frame_len
    frame_fires = frame_fires[:, :max_label_len, :]
    return frame_fires, fires
class mae_loss(torch.nn.Module):
    """Summed L1 distance between two length tensors, divided by a normaliser.

    Args:
        normalize_length: when True, divide by the sum of ``token_length``;
            otherwise divide by its first-dimension size.
    """

    def __init__(self, normalize_length=False):
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction="sum")

    def forward(self, token_length, pre_token_length):
        """Return the normalised L1 loss between ``token_length`` and ``pre_token_length``."""
        if self.normalize_length:
            denominator = token_length.sum().type(torch.float32)
        else:
            denominator = token_length.size(0)
        return self.criterion(token_length, pre_token_length) / denominator
def cif(hidden, alphas, threshold):
    """Integrate-and-fire over time: accumulate weighted frames, emit one per firing.

    Args:
        hidden: 3-D tensor unpacked as ``(batch_size, len_time, hidden_size)``.
        alphas: per-step weights, indexed as ``alphas[:, t]``.
        threshold: accumulated weight at which a frame is emitted.

    Returns:
        ``(embeddings, fires)``: the fired frames per batch row, zero-padded to
        the largest rounded alpha sum, and the accumulated weight at every step.
    """
    batch_size, len_time, hidden_size = hidden.size()
    dev = hidden.device
    unit = torch.ones([batch_size], device=dev)
    # Running state per batch row.
    accum = torch.zeros([batch_size], device=dev)
    partial = torch.zeros([batch_size, hidden_size], device=dev)
    # Per-step history.
    fire_steps = []
    frame_steps = []
    for step in range(len_time):
        weight = alphas[:, step]
        state = hidden[:, step, :]
        # Weight still missing (before this step) to complete the current frame.
        room = unit - accum
        accum += weight
        fire_steps.append(accum)
        fired = accum >= threshold
        # torch.where builds a new tensor, so the recorded value stays intact.
        accum = torch.where(fired, accum - unit, accum)
        used = torch.where(fired, room, weight)
        leftover = weight - used
        partial += used[:, None] * state
        frame_steps.append(partial)
        # Rows that fired restart from the leftover share of this step.
        partial = torch.where(
            fired[:, None].repeat(1, hidden_size), leftover[:, None] * state, partial
        )
    fires = torch.stack(fire_steps, 1)
    frames = torch.stack(frame_steps, 1)
    max_label_len = torch.round(alphas.sum(-1)).int().max()
    padded_rows = []
    for row_fires, row_frames in zip(fires, frames):
        fired_at = torch.nonzero(row_fires >= threshold).squeeze()
        picked = torch.index_select(row_frames, 0, fired_at)
        filler = torch.zeros([max_label_len - picked.size(0), hidden_size], device=dev)
        padded_rows.append(torch.cat([picked, filler], 0))
    return torch.stack(padded_rows, 0), fires
def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
    """Vectorized computation of the CIF fire values, without hidden states.

    A step fires when the floor of the cumulative weight sum increases from
    the previous step to this one.

    Args:
        alphas: 2-D tensor of per-frame weights, unpacked as (batch, time).
        threshold: accepted for signature compatibility; this implementation
            does not read it (firing is decided on integer crossings of the
            cumulative sum).
        return_fire_idxs: when True, also return the boolean fire mask.

    Returns:
        ``fires`` with the same shape as ``alphas``: the fractional part of
        the cumulative sum, plus 1 at firing steps. With
        ``return_fire_idxs=True`` the result is ``(fires, fire_idxs)``.
    """
    batch_size, len_time = alphas.size()
    device = alphas.device
    dtype = alphas.dtype
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)

    # Only query the CUDA runtime when the tensor actually lives on a CUDA
    # device; the unconditional call fails on CPU-only runs and would inspect
    # the *current* device rather than the one holding ``alphas``.
    on_bi_v150 = (
        device.type == "cuda" and torch.cuda.get_device_name(device) == "Iluvatar BI-V150"
    )
    if on_bi_v150:
        # NOTE(review): presumably float64 cumsum is unsupported or too slow on
        # BI-V150 — confirm with the vendor; precision is reduced on this path.
        prefix_sum = torch.cumsum(alphas, dim=1)
    else:
        # cumsum precision degradation causes wrong results in extreme cases,
        # so accumulate in float64 and cast back.
        prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32)

    prefix_sum_floor = torch.floor(prefix_sum)
    # Previous step's cumulative sum; roll wraps the last column to the front,
    # so column 0 is reset to 0 afterwards.
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
    dislocation_prefix_sum_floor[:, 0] = 0
    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor

    fire_idxs = dislocation_diff > 0
    fires[fire_idxs] = 1
    fires = fires + prefix_sum - prefix_sum_floor
    if return_fire_idxs:
        return fires, fire_idxs
    return fires
def cif_v1(hidden, alphas, threshold):
    """Vectorized CIF: integrate ``hidden`` with ``alphas`` and gather fired frames.

    Args:
        hidden: tensor unpacked as (batch, time, hidden_size).
        alphas: per-frame weights, (batch, time).
        threshold: forwarded to ``cif_wo_hidden_v1``.

    Returns:
        ``(frame_fires, fires)``: ``frame_fires`` is a zero-padded tensor of
        shape (batch, max_label_len, hidden_size) with
        ``max_label_len = round(alphas.sum(-1)).max()``; ``fires`` is the
        value returned by ``cif_wo_hidden_v1``.
    """
    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)

    device = hidden.device
    dtype = hidden.dtype
    batch_size, _, hidden_size = hidden.size()

    # Broadcasting replaces the explicit repeat over the hidden dimension
    # (same values, no materialized copy). The unused zero-filled ``frames``
    # buffer that used to be allocated here has been dropped.
    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1) * hidden, dim=1)

    # Weighted prefix sums at the firing steps, flattened over the batch.
    frames = prefix_sum_hidden[fire_idxs]
    shift_frames = torch.roll(frames, 1, dims=0)

    # Row index in the flattened ``frames`` where each batch item starts.
    batch_len = fire_idxs.sum(1)
    batch_idxs = torch.cumsum(batch_len, dim=0)
    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
    shift_batch_idxs[0] = 0
    # A start equal to the total row count belongs to trailing items without
    # any fire (or to a batch with no fires at all); indexing with it would be
    # out of range, so such starts are dropped.
    shift_batch_idxs = shift_batch_idxs[shift_batch_idxs < frames.size(0)]
    # The first fired frame of every item has no predecessor to subtract.
    shift_frames[shift_batch_idxs] = 0

    # Share of the firing step's weight that belongs to the next frame.
    remains = fires - torch.floor(fires)
    remain_frames = remains[fire_idxs].unsqueeze(-1) * hidden[fire_idxs]
    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
    shift_remain_frames[shift_batch_idxs] = 0

    frames = frames - shift_frames + shift_remain_frames - remain_frames

    # torch.round is used to calculate the max length.
    max_label_len = torch.round(alphas.sum(-1)).int().max()

    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
    frame_fires_idxs = indices < batch_len.unsqueeze(1)
    frame_fires[frame_fires_idxs] = frames
    return frame_fires, fires
def cif_wo_hidden(alphas, threshold):
    """Step-by-step CIF fire values, without hidden states.

    Args:
        alphas: 2-D tensor of per-frame weights, unpacked as (batch, time).
        threshold: value the running sum must reach to fire; it is also the
            amount subtracted from the running sum after a fire.

    Returns:
        Tensor stacked along dim 1 holding the running sum at each step,
        recorded before the post-fire reset.
    """
    batch_size, len_time = alphas.size()
    integrate = torch.zeros([batch_size], device=alphas.device)

    history = []
    for step in range(len_time):
        # In-place add; the recorded tensor is not modified afterwards because
        # torch.where below rebinds ``integrate`` to a new tensor.
        integrate += alphas[:, step]
        history.append(integrate)

        fired = integrate >= threshold
        reset_amount = torch.ones([batch_size], device=alphas.device) * threshold
        integrate = torch.where(fired, integrate - reset_amount, integrate)

    return torch.stack(history, 1)

View File

@@ -0,0 +1,746 @@
import logging
import os
import random
import re
import string
import time
import traceback
from typing import Union
import torch
import torch.nn as nn
from funasr.metrics.compute_acc import compute_accuracy
from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable, to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
from transformers import AutoConfig, AutoModelForCausalLM
from funasr.models.fun_asr_nano.ctc import CTC
from funasr.models.fun_asr_nano.tools.utils import forced_align
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@tables.register("model_classes", "FunASRNano")
class FunASRNano(nn.Module):
def __init__(
    self,
    audio_encoder: str = None,
    audio_encoder_conf: dict = None,
    audio_adaptor: str = None,
    audio_adaptor_conf: dict = None,
    llm: str = None,
    llm_conf: dict = None,
    input_size: int = 80,
    length_normalized_loss: bool = False,
    **kwargs,
):
    """Build the audio encoder, the LLM, the audio adaptor and an optional CTC branch.

    Args:
        audio_encoder: encoder name; given to ``AutoModel`` when
            ``audio_encoder_conf["hub"] == "ms"``, otherwise looked up in
            ``tables.encoder_classes``.
        audio_encoder_conf: encoder settings. Keys read here: ``hub``,
            ``activation_checkpoint`` (default False) and ``freeze``
            (default True); in the non-hub branch the dict is also expanded
            into the encoder constructor.
        audio_adaptor: adaptor name looked up in ``tables.adaptor_classes``.
        audio_adaptor_conf: adaptor constructor arguments; ``encoder_dim`` and
            ``llm_dim`` are written into this dict (the caller's object is
            modified). ``freeze`` (default False) and ``use_low_frame_rate``
            (default False) are also read.
        llm: not read in this method.
        llm_conf: keys read: ``init_param_path``, ``load_kwargs``, ``freeze``
            (default True), ``activation_checkpoint`` and ``llm_dtype``
            (default ``"fp32"``).
        input_size: forwarded to the encoder class in the non-hub branch.
        length_normalized_loss: stored on the instance.
        **kwargs: CTC-related options (``ctc_decoder``, ``ctc_tokenizer``,
            ``ctc_tokenizer_conf``, ``dataset_conf``, ``ctc_vocab_size``,
            ``ctc_decoder_conf``, ``ctc_conf``, ``ctc_weight``,
            ``detach_ctc_decoder``).

    NOTE(review): the ``*_conf`` parameters default to None but are
    dereferenced unconditionally, so callers must always pass dicts.
    """
    super().__init__()
    # audio encoder
    hub = audio_encoder_conf.get("hub", None)
    self.audio_encoder_activation_checkpoint = audio_encoder_conf.get(
        "activation_checkpoint", False
    )
    if hub == "ms":
        from funasr import AutoModel

        model = AutoModel(model=audio_encoder, model_revision="master")
        # -1 marks "size unknown"; the adaptor/CTC dims are then left as configured.
        audio_encoder_output_size = (
            model.model.encoder_output_size
            if hasattr(model.model, "encoder_output_size")
            else -1
        )
        audio_encoder = (
            model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
        )
    else:
        encoder_class = tables.encoder_classes.get(audio_encoder)
        audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
        audio_encoder_output_size = audio_encoder.output_size()
    freeze = audio_encoder_conf.get("freeze", True)
    if freeze:
        for _, param in audio_encoder.named_parameters():
            param.requires_grad = False
        audio_encoder.eval()
    self.audio_encoder = audio_encoder
    # llm
    self.llm = None
    init_param_path = llm_conf.get("init_param_path", None)
    llm_dim = None
    llm_load_kwargs = llm_conf.get("load_kwargs", {})
    # The LLM is instantiated from its config only (from_config); no weights
    # are loaded by this method.
    config = AutoConfig.from_pretrained(init_param_path)
    model = AutoModelForCausalLM.from_config(config, **llm_load_kwargs)
    freeze = llm_conf.get("freeze", True)
    if freeze:
        for _, param in model.named_parameters():
            param.requires_grad = False
        model.eval()
    if llm_conf.get("activation_checkpoint", False):
        model.gradient_checkpointing_enable()
    self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
    self.llm = model.to(dtype_map[self.llm_dtype])
    llm_dim = model.get_input_embeddings().weight.shape[-1]
    # adaptor
    adaptor_class = tables.adaptor_classes.get(audio_adaptor)
    if audio_encoder_output_size > 0:
        audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
    audio_adaptor_conf["llm_dim"] = (
        llm_dim if llm_dim is not None else audio_adaptor_conf["llm_dim"]
    )
    audio_adaptor = adaptor_class(**audio_adaptor_conf)
    freeze = audio_adaptor_conf.get("freeze", False)
    if freeze:
        for _, param in audio_adaptor.named_parameters():
            param.requires_grad = False
        audio_adaptor.eval()
    self.audio_adaptor = audio_adaptor
    self.use_low_frame_rate = audio_adaptor_conf.get("use_low_frame_rate", False)
    # ctc decoder
    self.ctc_decoder = None
    # TODO: fix table name
    ctc_decoder_class = tables.adaptor_classes.get(kwargs.get("ctc_decoder", None))
    if ctc_decoder_class is not None:
        # Tokenizer settings come from kwargs when present there, otherwise
        # from kwargs["dataset_conf"].
        ctc_tokenizer = (
            kwargs.get("ctc_tokenizer", None)
            if "ctc_tokenizer" in kwargs
            else kwargs["dataset_conf"]["ctc_tokenizer"]
        )
        ctc_tokenizer_conf = (
            kwargs.get("ctc_tokenizer_conf", None)
            if "ctc_tokenizer_conf" in kwargs
            else kwargs["dataset_conf"]["ctc_tokenizer_conf"]
        )
        if ctc_tokenizer is not None and ctc_tokenizer_conf is not None:
            ctc_tokenizer_class = tables.tokenizer_classes.get(ctc_tokenizer)
            ctc_tokenizer = ctc_tokenizer_class(**ctc_tokenizer_conf)
        self.ctc_tokenizer = ctc_tokenizer
        assert ctc_tokenizer is not None, f"ctc_tokenizer must be set"
        ctc_vocab_size = kwargs.get("ctc_vocab_size", 60515)
        ctc_decoder_conf = kwargs.get("ctc_decoder_conf", {})
        if audio_encoder_output_size > 0:
            ctc_decoder_conf["encoder_dim"] = audio_encoder_output_size
        self.ctc_decoder = ctc_decoder_class(**ctc_decoder_conf)
        init_param_path = ctc_decoder_conf.get("init_param_path", None)
        if init_param_path is not None:
            # NOTE(review): torch.load unpickles the checkpoint; only load
            # files from trusted sources.
            src_state = torch.load(init_param_path, map_location="cpu")
            flag = self.ctc_decoder.load_state_dict(src_state, strict=False)
            logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}")
        freeze = ctc_decoder_conf.get("freeze", False)
        if freeze:
            for _, param in self.ctc_decoder.named_parameters():
                param.requires_grad = False
            self.ctc_decoder.eval()
        ctc_conf = kwargs.get("ctc_conf", {})
        # The blank id defaults to the last vocabulary index.
        self.blank_id = ctc_conf.get("blank_id", ctc_vocab_size - 1)
        self.ctc_weight = kwargs.get("ctc_weight", 0.3)
        self.ctc = CTC(
            odim=ctc_vocab_size,
            encoder_output_size=audio_encoder_output_size,
            blank_id=self.blank_id,
            **ctc_conf,
        )
        self.detach_ctc_decoder = kwargs.get("detach_ctc_decoder", True)
        self.error_calculator = None
    self.length_normalized_loss = length_normalized_loss
    rank = int(os.environ.get("RANK", 0))
    logging.info(f"rank: {rank}, model is builded.")
def forward(
    self,
    speech: torch.Tensor = None,
    speech_lengths: torch.Tensor = None,
    input_ids: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
    labels_ids: torch.Tensor = None,
    fbank_beg: torch.Tensor = None,
    fbank_mask: torch.Tensor = None,
    **kwargs,
):
    """Compute the LLM loss with audio embeddings spliced into the prompt.

    Token embeddings are looked up for ``input_ids``; when ``speech`` is
    given, it is encoded and adapted and the result overwrites the embedding
    rows starting at each positive ``fbank_beg`` offset. The LLM is then run
    with ``labels_ids`` to obtain the loss.

    Args:
        speech: audio features, unpacked as (batch, frames, dim); may be None.
        speech_lengths: feature lengths; only column 0 is used when 2-D.
        input_ids: token ids, (batch, tokens).
        attention_mask: mask passed to the LLM.
        labels_ids: target ids passed to the LLM as ``labels``.
        fbank_beg: per-turn start offsets of the audio placeholders.
        fbank_mask: not read in this method.
        **kwargs: must contain ``fake_token_len`` when ``speech`` is given.

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable``.

    Note:
        ``input_ids``, ``fake_token_len``, ``fbank_beg``, ``labels_ids`` and
        ``attention_mask`` are modified in place.
    """
    batch_size, token_num = input_ids.shape
    stats = {}
    # Negative ids are replaced by 0 so the embedding lookup cannot fail.
    input_ids[input_ids < 0] = 0
    inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
    if speech is not None:
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size_speech, frames, _ = speech.shape
        # audio encoder
        if self.audio_encoder_activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # audio_adaptor
        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
        batch_size, token_num, dims = inputs_embeds.shape
        fake_token_len = kwargs.get("fake_token_len")
        fake_token_len[fake_token_len < 0] = 0
        fbank_beg[fbank_beg < 0] = 0
        # Index into the speech batch; advanced once per turn that has audio.
        speech_idx = 0
        for batch_idx in range(batch_size):
            for turn_id in range(fbank_beg.shape[1]):
                fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
                # An offset of 0 (including clamped negatives) means "no audio in this turn".
                if fbank_beg_idx > 0:
                    speech_token_len = fake_token_len[batch_idx, turn_id]
                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
                    try:
                        inputs_embeds[
                            batch_idx,
                            fbank_beg_idx : fbank_beg_idx + speech_token_len,
                            :,
                        ] = speech_token
                    except Exception as e:
                        # The assignment failed with the precomputed length;
                        # log the details and retry with the length reported
                        # by the adaptor.
                        logging.error(f"{str(e)}, {traceback.format_exc()}")
                        logging.info(
                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
                        )
                        speech_token_len = encoder_out_lens[speech_idx].item()
                        speech_token = encoder_out[speech_idx, :speech_token_len, :]
                        inputs_embeds[
                            batch_idx,
                            fbank_beg_idx : fbank_beg_idx + speech_token_len,
                            :,
                        ] = speech_token
                    speech_idx += 1
        stats["batch_size_speech"] = batch_size_speech
        stats["batch_size_x_frames"] = frames * batch_size_speech
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
    device_type = next(self.parameters()).device.type
    # Autocast is only enabled when the LLM dtype is not fp32.
    with torch.autocast(
        device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
        enabled=True if self.llm_dtype != "fp32" else False,
        dtype=dtype_map[self.llm_dtype],
    ):
        # -1 labels are mapped to -100, the ignore label also used for the
        # accuracy computation below.
        labels_ids[labels_ids == -1] = -100
        attention_mask[attention_mask < 0] = 0
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]),
            attention_mask=attention_mask,
            labels=labels_ids,
        )
        loss = model_outputs.loss
    with torch.no_grad():
        preds = torch.argmax(model_outputs.logits, -1)
        # Predictions at position t are compared with the label at t + 1.
        acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
        stats["acc"] = acc_att
    stats["loss"] = torch.clone(loss.detach())
    stats["batch_size"] = batch_size
    stats["batch_size_x_tokens"] = token_num * batch_size
    stats["batch_size_real_tokens"] = attention_mask.sum().item()
    stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
    dialog_turns = (fbank_beg > 0).sum(-1)
    dialog_turns_max = torch.max(dialog_turns).int().item()
    dialog_turns_avg = dialog_turns.sum().item() / batch_size
    stats["dialog_turns_max"] = dialog_turns_max
    stats["dialog_turns_avg"] = dialog_turns_avg
    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    if self.length_normalized_loss:
        # NOTE(review): `0 + 1` binds tighter than `>`, so this counts labels
        # greater than 1 — confirm whether `(labels_ids > 0).sum() + 1` or a
        # count of non-ignored labels was intended.
        batch_size = int((labels_ids > 0 + 1).sum())
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def forward_export(self, speech, speech_lengths, **kwargs):
    """Run the audio encoder followed by the audio adaptor (export entry point).

    Returns the adaptor's two outputs as a tuple; ``kwargs`` is ignored.
    """
    feats, feat_lens = self.audio_encoder(speech, speech_lengths)
    adapted, adapted_lens = self.audio_adaptor(feats, feat_lens)
    return adapted, adapted_lens
def encode(self, speech, speech_lengths):
    """Apply ``self.audio_encoder`` and return its two outputs as a tuple."""
    enc_out, enc_out_lens = self.audio_encoder(speech, speech_lengths)
    return enc_out, enc_out_lens
def data_template(self, data):
    """Group chat messages by role.

    Each item of ``data`` must provide ``"role"`` and ``"content"``. A user
    item that also has ``"audio"`` is stored as ``[content, audio]``. Items
    with any other role are dropped. The collected system prompts are
    repeated once per user turn.

    Returns:
        dict with the keys ``"system"``, ``"user"`` and ``"assistant"``.
    """
    system_msgs, user_msgs, assistant_msgs = [], [], []
    for item in data:
        role = item["role"]
        content = item["content"]
        if role == "user":
            user_msgs.append([content, item["audio"]] if "audio" in item else content)
        elif role == "system":
            system_msgs.append(content)
        elif role == "assistant":
            assistant_msgs.append(content)
    return {
        "system": system_msgs * len(user_msgs),
        "user": user_msgs,
        "assistant": assistant_msgs,
    }
def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data=None, **kwargs):
    """Turn grouped chat turns into token ids, labels and audio features.

    Every turn is rendered into a ChatML prompt. Text parts are tokenized;
    each ``<|startofspeech|>!...<|endofspeech|>`` span is loaded as audio,
    converted to fbank features and represented in the id sequence by a run
    of placeholder zeros.

    Args:
        contents: dict with ``"system"``, ``"user"`` and ``"assistant"`` lists
            (the shape produced by ``data_template``).
        tokenizer: object providing ``encode(str) -> list[int]``.
        frontend: feature frontend; ``fs``, ``frame_shift`` and ``lfr_n`` are
            read when audio is present.
        meta_data: optional dict that receives timing entries
            (``load_data``, ``extract_feat``, ``batch_data_time``). A fresh
            dict is used when omitted, so calls never share state.
        **kwargs: options read here: ``dataset_conf`` (``do_think``,
            ``sys_prompt``), ``multiturn_num_max`` (default 5),
            ``max_token_length`` (default 1500),
            ``infer_with_assistant_input``, ``prev_text`` and ``data_type``;
            all kwargs are also forwarded to the audio loader.

    Returns:
        dict with ``speech``, ``speech_lengths``, ``fbank_mask``,
        ``fbank_beg``, ``fake_token_len``, ``input_ids``, ``attention_mask``,
        ``labels_ids``, ``source_ids`` and ``target_ids``. ``speech`` and
        ``speech_lengths`` are empty lists when no audio was found.

    Raises:
        Exception: whatever the audio loader raised; the failure is logged
            and then re-raised instead of continuing with missing or stale
            audio data.
    """
    # A mutable default would be shared (and mutated) across calls.
    if meta_data is None:
        meta_data = {}
    system = contents["system"]
    user = contents["user"]
    assistant = contents["assistant"]
    pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
    do_think = True
    sys_prompt = True
    if "dataset_conf" in kwargs:
        do_think = kwargs["dataset_conf"].get("do_think", True)
        sys_prompt = kwargs["dataset_conf"].get("sys_prompt", True)
    input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
        [],
        [],
        [],
        [],
        [],
        [],
        [],
    )
    input_source_ids = []
    for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
        if i >= kwargs.get("multiturn_num_max", 5):
            break
        if len(input_ids) > kwargs.get("max_token_length", 1500):
            break
        # A user turn with audio arrives as [text, audio].
        if isinstance(user_prompt, (list, tuple)):
            user_prompt, audio = user_prompt
        if i == 0:
            if kwargs.get("infer_with_assistant_input", False):
                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}"
                if not sys_prompt:
                    source_input = f"<|im_start|>user\n{user_prompt}"
            else:
                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                if not sys_prompt:
                    source_input = (
                        f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                    )
        else:
            if kwargs.get("infer_with_assistant_input", False):
                source_input = f"<|im_start|>user\n{user_prompt}"
            else:
                source_input = (
                    f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                )
        if not do_think:
            # An empty think block is appended so the model skips reasoning.
            source_input += "<think>\n\n</think>\n\n"
        if kwargs.get("prev_text", None) is not None:
            source_input += kwargs["prev_text"]
        splits = pattern.split(source_input)
        source_ids = []
        fbank_mask_i = []
        fake_token_len_i = 0
        fbank_beg_i = -1
        speech, speech_lengths = [], []
        for k, sub_str in enumerate(splits):
            if not sub_str.startswith("<|startofspeech|>"):
                sub_token = tokenizer.encode(sub_str)
                source_ids += sub_token
                fbank_mask_i += [0] * len(sub_token)
            else:
                sub_str = sub_str.replace("<|startofspeech|>", "").replace(
                    "<|endofspeech|>", ""
                )
                if sub_str.startswith("!"):
                    sub_str = sub_str[1:]
                    if sub_str.startswith("!"):  # !!: audio sample point
                        sub_str = audio
                    try:
                        time1 = time.perf_counter()
                        data_src = load_audio_text_image_video(
                            sub_str, fs=frontend.fs, **kwargs
                        )
                        time2 = time.perf_counter()
                        meta_data["load_data"] = f"{time2 - time1:0.3f}"
                    except Exception as e:
                        logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
                        # Without the loaded audio the code below would hit an
                        # unbound ``data_src`` (or reuse the previous turn's
                        # audio), so the original error is propagated.
                        raise
                    speech, speech_lengths = extract_fbank(
                        data_src,
                        data_type=kwargs.get("data_type", "sound"),
                        frontend=frontend,
                        is_final=True,
                    )  # speech: [b, T, d]
                    time3 = time.perf_counter()
                    meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
                    meta_data["batch_data_time"] = (
                        speech_lengths.sum().item()
                        * frontend.frame_shift
                        * frontend.lfr_n
                        / 1000
                    )
                    if self.use_low_frame_rate:
                        # Length after three successive reductions by about 2x.
                        # NOTE(review): presumably mirrors the encoder/adaptor
                        # downsampling — confirm against those modules.
                        olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                        olens = 1 + (olens - 3 + 2 * 1) // 2
                        fake_token_len_i = (olens - 1) // 2 + 1
                    else:
                        fake_token_len_i = speech_lengths[0].item()
                    # Placeholder ids that mark where audio embeddings go.
                    fake_token = [0] * fake_token_len_i
                    fbank_beg_i = len(source_ids)
                    source_ids += fake_token
                    fbank_mask_i += [1] * len(fake_token)
        # Audio offset relative to the whole dialog (stays negative when the
        # first turn has no audio).
        fbank_beg += [fbank_beg_i + len(input_ids)]
        fake_token_len += [fake_token_len_i]
        # Prompt positions are excluded from the loss.
        source_mask = [-100] * len(source_ids)
        target_out = f"{target_out}<|im_end|>"
        target_ids = tokenizer.encode(target_out)
        input_source_ids = input_ids + source_ids
        input_ids += source_ids + target_ids
        labels += source_mask + target_ids
        fbank_mask += fbank_mask_i
        if len(speech) > 0:
            fbank.append(speech[0, :, :])
            fbank_lens.append(speech_lengths)
    input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
    attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
    labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
    fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
    fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
    fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
    source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
    target_ids = torch.tensor(target_ids, dtype=torch.int64)
    if len(fbank) > 0:
        speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
        speech_lengths = torch.nn.utils.rnn.pad_sequence(
            fbank_lens, batch_first=True, padding_value=-1
        )
    else:
        speech = []
        speech_lengths = []
    output = {
        "speech": speech,
        "speech_lengths": speech_lengths,
        "fbank_mask": fbank_mask[None, :],
        "fbank_beg": fbank_beg[None,],
        "fake_token_len": fake_token_len[None, :],
        "input_ids": input_ids[None,],
        "attention_mask": attention_mask[None,],
        "labels_ids": labels,
        "source_ids": source_ids[None, :],
        "target_ids": target_ids[None, :],
    }
    return output
def inference_prepare(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Build the LLM input embeddings for a single inference request.

    Only ``data_in[0]`` is processed; ``batch_size`` greater than 1 is
    rejected.

    Args:
        data_in: list of chat-message lists; the first one is used.
        data_lengths: not read in this method.
        key: not read in this method.
        tokenizer: forwarded to ``data_load_speech``.
        frontend: forwarded to ``data_load_speech``.
        **kwargs: must contain ``device``. Optional: ``batch_size``,
            ``audio_embedding`` together with ``audio_embedding_lens`` (used
            in place of running the encoder), ``fp16`` / ``bf16``,
            ``teacherforcing``.

    Returns:
        ``(inputs_embeds, contents, batch, source_ids, meta_data)``.

    Raises:
        NotImplementedError: when ``kwargs["batch_size"] > 1``.
    """
    meta_data = {}
    if kwargs.get("batch_size", 1) > 1:
        raise NotImplementedError("batch decoding is not implemented")
    contents = self.data_template(data_in[0])
    output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
    batch = to_device(output, kwargs["device"])
    # audio encoder
    speech = batch["speech"]
    if len(speech) > 0:
        if "audio_embedding" in kwargs and "audio_embedding_lens" in kwargs:
            # Caller supplied encoder outputs; the encoder is skipped.
            encoder_out = kwargs["audio_embedding"]
            encoder_out_lens = kwargs["audio_embedding_lens"]
        else:
            speech_lengths = batch["speech_lengths"][:, 0]
            # fp16
            if kwargs.get("fp16", False):
                speech = speech.to(torch.float16)
            elif kwargs.get("bf16", False):
                speech = speech.to(torch.bfloat16)
            # audio encoder
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # audio_adaptor
        adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
        meta_data["encoder_out"] = encoder_out
        meta_data["encoder_out_lens"] = encoder_out_lens
        meta_data["audio_adaptor_out"] = adaptor_out
        meta_data["audio_adaptor_out_lens"] = adaptor_out_lens
    input_ids = batch["input_ids"]
    source_ids = batch["source_ids"]
    fbank_beg = batch["fbank_beg"]
    fake_token_len = batch["fake_token_len"]
    # Without teacher forcing only the prompt (no target tokens) is embedded.
    if not kwargs.get("teacherforcing", False):
        input_ids = source_ids
    input_ids[input_ids < 0] = 0
    inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
    batch_size, token_num, dims = inputs_embeds.shape
    fake_token_len[fake_token_len < 0] = 0
    fbank_beg[fbank_beg < 0] = 0
    # Index into the adaptor output; advanced once per turn that has audio.
    speech_idx = 0
    for batch_idx in range(batch_size):
        for turn_id in range(fbank_beg.shape[1]):
            fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
            # An offset of 0 (including clamped negatives) means "no audio in this turn".
            if fbank_beg_idx > 0:
                speech_token_len = fake_token_len[batch_idx, turn_id]
                speech_token = adaptor_out[speech_idx, :speech_token_len, :]
                try:
                    inputs_embeds[
                        batch_idx,
                        fbank_beg_idx : fbank_beg_idx + speech_token_len,
                        :,
                    ] = speech_token
                except Exception as e:
                    # The assignment failed with the precomputed length; log
                    # the details and retry with the adaptor's own length.
                    # NOTE(review): ``speech_lengths`` is unbound here when the
                    # ``audio_embedding`` path was taken — confirm.
                    logging.error(f"{str(e)}, {traceback.format_exc()}")
                    logging.info(
                        f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, adaptor_out: {adaptor_out.shape}, adaptor_out_lens: {adaptor_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
                    )
                    speech_token_len = adaptor_out_lens[speech_idx].item()
                    speech_token = adaptor_out[speech_idx, :speech_token_len, :]
                    inputs_embeds[
                        batch_idx,
                        fbank_beg_idx : fbank_beg_idx + speech_token_len,
                        :,
                    ] = speech_token
                speech_idx += 1
    return inputs_embeds, contents, batch, source_ids, meta_data
def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
    """Compose the transcription instruction given to the LLM.

    Args:
        hotwords: hotword strings; when non-empty they are joined with
            ``", "`` and placed in a context section ahead of the instruction.
        language: target language name; when None the generic transcription
            instruction is used.
        itn: when False, a clause asking for no text normalization is added.

    Returns:
        The assembled prompt string.
    """
    prompt = ""
    if len(hotwords) > 0:
        hotwords = ", ".join(hotwords)
        prompt = f"请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
        prompt += f"热词列表:[{hotwords}]\n"
    prompt += "语音转写" if language is None else f"语音转写成{language}"
    if not itn:
        prompt += ",不进行文本规整"
    return prompt + ""
def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
    """Wrap one audio input into a three-message chat (system, user, assistant).

    A string input is embedded in the user message after a single ``!``
    marker; a tensor input uses the ``!!`` marker and is attached under the
    ``"audio"`` key. Any other input type yields None.
    """
    if isinstance(data, str):
        user_turn = {
            "role": "user",
            "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>",
        }
    elif isinstance(data, torch.Tensor):
        user_turn = {
            "role": "user",
            "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
            "audio": data,
        }
    else:
        return None
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        user_turn,
        {"role": "assistant", "content": "null"},
    ]
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Transcribe raw inputs by wrapping them in a chat and calling ``inference_llm``.

    The prompt is built from the ``hotwords``, ``language`` and ``itn``
    kwargs. When ``key`` is None, one random key of the form
    ``rand_key_<13 alphanumerics>`` is generated per input. All kwargs are
    forwarded unchanged.
    """
    prompt = self.get_prompt(
        kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
    )
    data_in = [self.generate_chatml(prompt, data) for data in data_in]

    if key is None:
        alphabet = string.ascii_letters + string.digits
        key = []
        for _ in data_in:
            suffix = "".join(random.choice(alphabet) for _ in range(13))
            key.append("rand_key_" + suffix)

    return self.inference_llm(
        data_in,
        data_lengths=data_lengths,
        key=key,
        tokenizer=tokenizer,
        frontend=frontend,
        **kwargs,
    )
def inference_llm(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Decode one request with the LLM and, when available, the CTC branch.

    Args:
        data_in: chat-formatted inputs, forwarded to ``inference_prepare``.
        data_lengths: forwarded to ``inference_prepare``.
        key: utterance keys; ``key[0]`` names the result.
        tokenizer: LLM tokenizer used to decode the generated ids.
        frontend: forwarded to ``inference_prepare``.
        **kwargs: options read here: ``llm_dtype``, ``fp16``, ``bf16``,
            ``device`` (default ``"cuda"``), ``llm_kwargs``,
            ``teacherforcing``, ``max_length`` (default 512),
            ``skip_special_tokens`` (default True), ``prev_text`` and
            ``output_dir``.

    Returns:
        ``(results, meta_data)``. ``results`` holds one dict with ``key``,
        ``text``, ``text_tn`` and ``label``; ``loss`` is added under teacher
        forcing, and ``ctc_text``, ``ctc_timestamps`` and ``timestamps`` when
        a CTC decoder is configured.
    """
    inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
        data_in, data_lengths, key, tokenizer, frontend, **kwargs
    )
    ctc_results = []
    if self.ctc_decoder is not None:
        encoder_out = meta_data["encoder_out"]
        encoder_out_lens = meta_data["encoder_out_lens"]
        decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
        ctc_logits = self.ctc.log_softmax(decoder_out)
        b, n, d = encoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        # Repeat the keys so that every batch item can be indexed below.
        if len(key) < b:
            key = key * b
        for i in range(b):
            x = ctc_logits[i, : encoder_out_lens[i].item(), :]
            # Greedy CTC decoding: best id per frame, merge repeats, drop blanks.
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)
            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()
            # Change integer-ids to tokens
            text = self.ctc_tokenizer.decode(token_int)
            ctc_results.append({"key": key[i], "text": text, "ctc_logits": x})
    # An explicit llm_dtype wins; otherwise the fp16 / bf16 flags pick the
    # dtype, with bf16 taking precedence when both are set.
    llm_dtype = kwargs.get("llm_dtype", "fp32")
    if llm_dtype == "fp32":
        llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
        llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
    device_type = torch.device(kwargs.get("device", "cuda")).type
    with torch.autocast(
        device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
        enabled=True if llm_dtype != "fp32" else False,
        dtype=dtype_map[llm_dtype],
    ):
        label = contents["assistant"][-1]
        # The LLM weights are cast (and the attribute rebound) on every call.
        self.llm = self.llm.to(dtype_map[llm_dtype])
        inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
        llm_kwargs = kwargs.get("llm_kwargs", {})
        if not kwargs.get("teacherforcing", False):
            attention_mask = batch.get("attention_mask", None)
            generated_ids = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=kwargs.get("max_length", 512),
                pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
                **llm_kwargs,
            )
            response = tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=kwargs.get("skip_special_tokens", True),
            )[0]
            loss = None
        else:
            labels_ids = batch["labels_ids"]
            labels_ids[labels_ids == -1] = -100
            attention_mask = batch.get("attention_mask", None)
            # NOTE(review): ``pad_token_id`` is passed to the model's forward
            # call here — confirm the wrapped model accepts it.
            model_outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels_ids,
                pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id,
                **llm_kwargs,
            )
            # Keep only the predictions that follow the prompt part.
            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
            response = tokenizer.batch_decode(
                preds,
                add_special_tokens=False,
                skip_special_tokens=kwargs.get("skip_special_tokens", True),
            )[0]
            loss = model_outputs.loss.item()
    # NOTE(review): an explicit ``prev_text=None`` would make this
    # concatenation fail — confirm callers never pass None.
    response = kwargs.get("prev_text", "") + response
    ibest_writer = None
    if kwargs.get("output_dir") is not None:
        if not hasattr(self, "writer"):
            self.writer = DatadirWriter(kwargs.get("output_dir"))
        ibest_writer = self.writer[f"{0 + 1}best_recog"]
    results = []
    # Strip everything that is not a word character, whitespace or CJK.
    response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
    result_i = {
        "key": key[0],
        "text": re.sub(r"\s+", " ", response.replace("/sil", " ")),
        "text_tn": response_clean,
        "label": label,
    }
    if loss is not None:
        result_i["loss"] = loss
    results.append(result_i)
    for ctc_result, result in zip(ctc_results, results):
        result["ctc_text"] = ctc_result["text"].replace("<|nospeech|>", "")
        # Align both the CTC text and the LLM text against the CTC logits.
        target_ids = torch.tensor(
            self.ctc_tokenizer.encode(result["ctc_text"]), dtype=torch.int64
        )
        result["ctc_timestamps"] = forced_align(
            ctc_result["ctc_logits"], target_ids, self.blank_id
        )
        target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64)
        result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id)
        for timestamps in [result["timestamps"], result["ctc_timestamps"]]:
            for timestamp in timestamps:
                timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]])
                # NOTE(review): presumably 6 frames x 10 ms per CTC step,
                # converted to seconds — confirm against the frontend settings.
                timestamp["start_time"] = timestamp["start_time"] * 6 * 10 / 1000
                timestamp["end_time"] = timestamp["end_time"] * 6 * 10 / 1000
    if ibest_writer is not None:
        ibest_writer["text"][key[0]] = response.replace("\n", " ")
        ibest_writer["label"][key[0]] = label.replace("\n", " ")
        ibest_writer["text_tn"][key[0]] = response_clean
    return results, meta_data
@staticmethod
def from_pretrained(model: str = None, **kwargs):
    """Build a model through ``funasr.AutoModel.build_model``.

    ``trust_remote_code=True`` is always passed. Returns the
    ``(model, kwargs)`` pair produced by ``build_model``.
    """
    # Imported lazily, as in the rest of this class.
    from funasr import AutoModel

    built_model, built_kwargs = AutoModel.build_model(
        model=model, trust_remote_code=True, **kwargs
    )
    return built_model, built_kwargs

14
funasr/requirements.txt Normal file
View File

@@ -0,0 +1,14 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PYYaml
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
fastapi
uvicorn
python-multipart

BIN
funasr/warmup.wav Normal file

Binary file not shown.

Binary file not shown.

1
sample_data/lei-jun.txt Normal file
View File

@@ -0,0 +1 @@
朋友们晚上好欢迎大家来参加今天晚上的活动谢谢大家。这是我第四次办年度演讲前三次呢因为疫情的原因都在小米科技园内举办。现场呢人很少。这是第四次我们仔细想了想我们还是想办一个比较大的聚会。然后呢让我们的新朋友老朋友一起聚一聚。今天的话呢我们就在北京的国家会议中心呢举办了这么一个活动。现场呢来了很多人大概有3500人。还有很多很多的朋友呢通过观看直播的方式来参与。再一次呢对大家的参加表示感谢谢谢大家。两个月前我参加了今年武汉大学的毕业典礼。今年呢是武汉大学建校130周年作为校友被母校邀请在毕业典礼上致辞这对我来说是至高无上的荣誉。站在讲台的那一刻面对全校师生关于武大的所有的记忆一下子涌现在脑海里。今天呢我就先和大家聊聊武大往事。那还是36年前1987年我呢考上了武汉大学的计算机系。在武汉大学的图书馆里看了一本书《硅谷之火》建立了我一生的梦想。看完书以后热血沸腾激动得睡不着觉。我还记得那天晚上星光很亮。我就在武大的操场上就是屏幕上这个操场走了一圈又一圈走了整整一个晚上。我心里有团火我也想办一个伟大的公司就是这样梦想之火在我心里彻底点燃了。但是一个大一的新生但是一个大一的新生一个从县城里出来的年轻人什么也不会什么也没有就想创办一家伟大的公司这不就是天方夜谭吗这么离谱的一个梦想该如何实现呢那天晚上我想了一整晚上说实话越想越糊涂完全理不清头绪后来我在想哎干脆别想了把书念好是正事所以呢我就下定决心认认真真读书那么我怎么能够把书读得不同凡响呢

View File

@@ -0,0 +1,25 @@
FROM corex:4.3.8
WORKDIR /root

RUN set -eux; \
    # 1) Replace the aliyun mirror with the official Ubuntu archive (avoids HTTP 403).
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list.d/*.list 2>/dev/null || true; \
    \
    # 2) Refresh the package index and install the system packages.
    apt-get update; \
    apt-get install -y --no-install-recommends vim net-tools ca-certificates libasound2-dev patchelf; \
    rm -rf /var/lib/apt/lists/*

# Install the Python dependencies before copying the rest of the build context,
# so these layers stay cached when only application code changes.
COPY requirements.txt /root/
RUN pip install --no-cache-dir -r requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple

# Vendor-built sherpa-onnx wheel for corex 4.3.8 / CPython 3.10.
COPY sherpa_onnx-1.12.5+corex4.3.8-cp310-cp310-linux_x86_64.whl /root/
RUN pip install --no-cache-dir ./sherpa_onnx-1.12.5+corex4.3.8-cp310-cp310-linux_x86_64.whl

# Application code. COPY is used instead of ADD: only a plain copy is needed.
COPY . /root/

# Make the TVM shared libraries shipped with corex visible to the dynamic loader.
ENV LD_LIBRARY_PATH=/usr/local/corex-4.3.8/lib64/python3/dist-packages/tvm/:$LD_LIBRARY_PATH

ENTRYPOINT ["python3"]
CMD ["./main_sherpa.py"]

28
sherpa-onnx/README.md Normal file
View File

@@ -0,0 +1,28 @@
# 天数智芯 天垓150 ASR（Sherpa-ONNX 架构）
## 镜像构造
```shell
docker build -f ./Dockerfile.sherpa-onnx-bi150 -t <your_image> .
```
其中,基础镜像 corex:4.3.8 通过联系天数智芯智铠100厂商技术支持可获取
## 使用说明
### 使用 FastAPI 启动ASR服务
例如:
```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
-v /mnt/contest_ceph/leaderboard/modelHubXC/mariolux/sherpa-onnx-dolphin-small-ctc-multi-lang-2025-04-02:/model \
--network=host <your_image> \
main_sherpa.py --model_dir /model --model_type dolphin_ctc --offline_model --use_gpu --port 1111
```
具体参数代码设定可参考代码文件
### 测试ASR服务
项目根路径`sample_data`目录下附带上了中文的测试音频和附带内容
```shell
curl -X POST http://localhost:1111/transduce \
-F "audio=@../sample_data/lei-jun-test.wav" \
-F "lang=zh"
```

View File

@@ -0,0 +1,520 @@
import os
import sys
import time
import uuid
import datetime
import tempfile
import soundfile as sf
import sherpa_onnx
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
# Created at import time. NOTE(review): presumably the directory where
# incoming audio is stored — confirm against the request handlers.
os.makedirs("./input", exist_ok=True)
# Module-level state; load_model() declares these names as globals.
status = "Running"
recognizer = None
device = ""
model_type = ""
app = FastAPI()
# Optional accelerator hint from the environment; load_model() checks it for
# the "mlu" and "ascend" prefixes.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
# The model type is inferred from the model name. The naming is not uniform,
# so the following custom type labels are used (for OfflineRecognizer):
# moonshine
# fire_red
# dolphin_ctc
# paraformer
# telespeech_ctc
# whisper
# sensevoice
# zipformer_ctc
# transducer
# nemo_ctc
# nemo_canary
# wenet_ctc
# For OnlineRecognizer only zipformer_ctc, transducer, paraformer, nemo_ctc and
# wenet_ctc apply.
def get_asr_model_type(model_name):
    """Map a model name to one of the custom sherpa-onnx model type labels.

    Matching is case-insensitive and substring based; the first rule that
    matches wins, and ``"unknown"`` is returned when none does.

    Notes carried over from the original comments:
      * nemo_ctc, nemo_canary and moonshine currently have no Chinese models
        in sherpa-onnx and run English ASR; the other types run Chinese ASR.
      * No nemo model (nemo_ctc, nemo_canary, or the nemo transducers) has a
        Chinese variant, and not every family has English models either.
      * A zipformer model counts as zipformer_ctc only when its name contains
        "ctc"; otherwise it is a transducer.
      * A nemo model gets its own type only with "ctc" or "canary" in the
        name; otherwise it is a transducer.
      * conformer models are transducers, but must be checked after nemo.
      * "wenet" is checked last because wenetspeech is also a dataset name
        that appears in models of many types.
    """
    name = model_name.lower()

    # Families identified by a single keyword, in priority order.
    keyword_rules = (
        # tdnn is not usable: only one model exists and it only recognizes
        # the words yes/no in Hebrew.
        ("tdnn", "tdnn"),
        ("moonshine", "moonshine"),
        ("fire-red", "fire_red"),
        ("dolphin", "dolphin_ctc"),
        ("paraformer", "paraformer"),
        ("telespeech", "telespeech_ctc"),
        ("whisper", "whisper"),
        ("sense-voice", "sensevoice"),
    )
    for keyword, label in keyword_rules:
        if keyword in name:
            return label

    if "zipformer" in name:
        return "zipformer_ctc" if "ctc" in name else "transducer"
    if "nemo" in name:
        if "ctc" in name:
            return "nemo_ctc"
        if "canary" in name:
            return "nemo_canary"
        return "transducer"
    if "conformer" in name or "lstm" in name:
        return "transducer"
    if "wenet" in name:
        return "wenet_ctc"
    return "unknown"
# Each entry is (part name, substring identifying that part in an .onnx file
# name). A file is assigned to the FIRST entry whose substring it contains.
_SINGLE_MODEL_PARTS = (("model", ""),)
_ENCODER_DECODER_PARTS = (("encoder", "encode"), ("decoder", "decode"))
_TRANSDUCER_PARTS = (("encoder", "encode"), ("decoder", "decode"), ("joiner", "joiner"))
# "uncached_decode" has to be tested before "cached_decode" (its substring).
_MOONSHINE_PARTS = (
    ("preprocessor", "preprocess"),
    ("encoder", "encode"),
    ("uncached_decoder", "uncached_decode"),
    ("cached_decoder", "cached_decode"),
)
# model_type -> (offline factory name, online factory name or None,
#                keyword argument that receives the model path)
_SINGLE_MODEL_FACTORIES = {
    "paraformer": ("from_paraformer", "from_paraformer", "paraformer"),
    "zipformer_ctc": ("from_zipformer_ctc", "from_zipformer2_ctc", "model"),
    "telespeech_ctc": ("from_telespeech_ctc", None, "model"),
    "wenet_ctc": ("from_wenet_ctc", "from_wenet_ctc", "model"),
    "dolphin_ctc": ("from_dolphin_ctc", None, "model"),
    "nemo_ctc": ("from_nemo_ctc", "from_nemo_ctc", "model"),
    "tdnn_ctc": ("from_tdnn_ctc", None, "model"),
}


def _select_model_files(model_dir, file_list, parts):
    """Pick the tokens file and, for every model part, its largest .onnx file.

    A directory may hold several variants of the same model (for example a
    quantised and a full-precision export); the largest file of each part wins.

    Args:
        model_dir: Directory containing the model files.
        file_list: File names found in ``model_dir``.
        parts: Ordered ``(part_name, substring)`` pairs, see the tables above.

    Returns:
        Dict mapping every part name, plus ``"tokens"``, to a full path.

    Raises:
        FileNotFoundError: If the tokens file or any model part is missing.
    """
    candidates = {part: [] for part, _ in parts}
    tokens = None
    for file in file_list:
        if file.endswith(".onnx"):
            for part, keyword in parts:
                if keyword in file:
                    candidates[part].append(file)
                    break
        elif file.endswith(".txt") and "token" in file:
            tokens = file
    if tokens is None:
        raise FileNotFoundError(f"no tokens file (*token*.txt) found in {model_dir}")
    selected = {"tokens": os.path.join(model_dir, tokens)}
    for part, names in candidates.items():
        if not names:
            raise FileNotFoundError(f"no .onnx file for '{part}' found in {model_dir}")
        selected[part] = max(
            (os.path.join(model_dir, name) for name in names),
            key=os.path.getsize,
        )
    return selected


def _build_recognizer(model_type, model_dir, model_name, is_offline, provider, num_threads):
    """Create the sherpa-onnx recognizer that matches ``model_type``.

    Families without a streaming variant always produce an OfflineRecognizer,
    whatever ``is_offline`` says.

    Raises:
        RuntimeError: If ``model_type`` is not a supported family.
        FileNotFoundError: If a required model file is missing.
    """
    file_list = os.listdir(model_dir)
    common = {"debug": False, "provider": provider, "num_threads": num_threads}
    offline = sherpa_onnx.OfflineRecognizer
    online = sherpa_onnx.OnlineRecognizer

    if model_type == "whisper":
        files = _select_model_files(model_dir, file_list, _ENCODER_DECODER_PARTS)
        return offline.from_whisper(language="zh", **files, **common)
    if model_type == "sensevoice":
        files = _select_model_files(model_dir, file_list, _SINGLE_MODEL_PARTS)
        return offline.from_sense_voice(use_itn=True, language="zh", **files, **common)
    if model_type == "fire_red":
        files = _select_model_files(model_dir, file_list, _ENCODER_DECODER_PARTS)
        return offline.from_fire_red_asr(**files, **common)
    if model_type == "nemo_canary":
        files = _select_model_files(model_dir, file_list, _ENCODER_DECODER_PARTS)
        return offline.from_nemo_canary(**files, **common)
    if model_type == "moonshine":
        files = _select_model_files(model_dir, file_list, _MOONSHINE_PARTS)
        return offline.from_moonshine(**files, **common)
    if model_type == "transducer":
        files = _select_model_files(model_dir, file_list, _TRANSDUCER_PARTS)
        if is_offline:
            # zipformer/conformer are icefall exports and use the default
            # type; NeMo transducers have to be flagged explicitly.
            transducer_type = "nemo_transducer" if "nemo" in model_name.lower() else "transducer"
            return offline.from_transducer(model_type=transducer_type, **files, **common)
        return online.from_transducer(**files, **common)
    if model_type in _SINGLE_MODEL_FACTORIES:
        offline_name, online_name, path_keyword = _SINGLE_MODEL_FACTORIES[model_type]
        files = _select_model_files(model_dir, file_list, _SINGLE_MODEL_PARTS)
        kwargs = {path_keyword: files["model"], "tokens": files["tokens"], **common}
        if is_offline or online_name is None:
            return getattr(offline, offline_name)(**kwargs)
        return getattr(online, online_name)(**kwargs)
    raise RuntimeError("Cannot recognize model_type")


@app.on_event("startup")
def load_model():
    """Startup hook: read ``app.state.config`` and build the global recognizer.

    Sets the module globals ``recognizer``, ``device``, ``model_type`` and
    ``status`` ("Success" when loading and the optional warmup succeed,
    "Failed" when the recognizer cannot be created).

    Raises:
        RuntimeError: If the recognizer cannot be created.
    """
    global status, recognizer, device, model_type
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    _model_type = config.get("model_type", None)
    _model_name = config.get("model_name", None)
    warmup = config.get("warmup", False)
    isOffline = config.get("offline_model", True)
    num_threads = config.get("num_threads", 2)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    # Always bind a model name: the transducer branch needs it to spot NeMo
    # exports even when model_type was given explicitly. normpath() keeps the
    # last path component when model_dir ends with a slash.
    model_name = _model_name or os.path.basename(os.path.normpath(model_dir))

    # sherpa-onnx has many model families. The user may pass model_type
    # directly, or the full model name; the mounted directory inside the
    # container does not necessarily contain the model name.
    if _model_type:
        model_type = _model_type
    elif _model_name:
        model_type = get_asr_model_type(_model_name)
    else:
        print("model_name and model_type both not provided, start guessing using model_dir", flush=True)
        model_type = get_asr_model_type(model_name)

    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" model_type =", model_type, flush=True)
    print(" use_gpu =", use_gpu, flush=True)
    print(" warmup =", warmup, flush=True)
    print(" isOffline =", isOffline, flush=True)
    print(" num_threads =", num_threads, flush=True)

    try:
        recognizer = None
        provider = "cuda" if use_gpu else "cpu"
        recognizer = _build_recognizer(
            model_type, model_dir, model_name, isOffline, provider, num_threads
        )
    except Exception as e:
        # Let /health report "failed" should the process keep running.
        status = "Failed"
        raise RuntimeError(f"Failed to initial cuda model: {e}") from e

    if warmup:
        print("Start warmup...", flush=True)
        stream = recognizer.create_stream()
        audio, sample_rate = sf.read("warmup.wav", dtype="float32", always_2d=True)
        # Feed the first channel only, exactly like the inference path does.
        stream.accept_waveform(sample_rate, audio[:, 0])
        if hasattr(recognizer, "is_ready"):
            # Streaming recognizers decode only while enough frames are buffered.
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
        else:
            recognizer.decode_stream(stream)
        print("warmup complete.", flush=True)
    status = "Success"
def test_sherpa(wavefile):
    """Transcribe one audio file with the global recognizer and log timing.

    Only the first channel of the file is decoded.

    Args:
        wavefile: Path of an audio file readable by ``soundfile``.

    Returns:
        The recognised text.
    """
    isOffline = app.state.config.get("offline_model", True)
    # Model families without a streaming variant are always loaded as an
    # OfflineRecognizer, even when "offline_model" is not set. Only take the
    # streaming path when the recognizer really offers the streaming API.
    use_streaming = (not isOffline) and hasattr(recognizer, "is_ready")

    audio, sample_rate = sf.read(wavefile, dtype="float32", always_2d=True)
    audio = audio[:, 0]
    generated_text = ""
    start_t = datetime.datetime.now()
    if not use_streaming:
        # Non-streaming inference (OfflineRecognizer).
        if model_type in ["sensevoice"]:
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, audio)
            recognizer.decode_stream(stream)
            generated_text = stream.result.text
        else:
            # Most offline models handle long audio poorly (short training
            # utterances, upper bounds on intermediate dimensions in the
            # exported ONNX graph) and may crash even on CPU, so the audio is
            # decoded in consecutive 29-second segments.
            start_index = 0
            internal = int(sample_rate * 29)
            while start_index < len(audio):
                stream = recognizer.create_stream()
                stream.accept_waveform(sample_rate, audio[start_index:start_index + internal])
                recognizer.decode_stream(stream)
                generated_text += stream.result.text
                start_index += internal
    else:
        # Streaming inference (OnlineRecognizer): feed 2 seconds at a time.
        # NOTE(review): the stream is never marked as finished and no tail
        # padding is added; confirm against the sherpa-onnx streaming examples
        # whether trailing frames can be lost.
        stream = recognizer.create_stream()
        start_index = 0
        chunk_size = int(sample_rate * 2)
        while start_index < len(audio):
            chunk = audio[start_index:start_index + chunk_size]
            stream.accept_waveform(sample_rate, chunk)
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
            start_index += chunk_size
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        generated_text = recognizer.get_result(stream)
    end_t = datetime.datetime.now()

    elapsed_seconds = (end_t - start_t).total_seconds()
    duration = audio.shape[-1] / sample_rate
    # An empty file has zero duration; avoid a ZeroDivisionError in the log.
    rtf = elapsed_seconds / duration if duration > 0 else float("inf")
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{elapsed_seconds:.3f} s", flush=True)
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Report the service state derived from the module-level ``status`` flag.

    Returns ``{"status": "loading model"}`` while the model is still loading,
    ``{"status": "ok"}`` once loading succeeded and ``{"status": "failed"}``
    for any other state.
    """
    if status == "Running":
        return {"status": "loading model"}
    if status == "Success":
        return {"status": "ok"}
    return {"status": "failed"}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    """Transcribe an uploaded audio file.

    Args:
        audio: Uploaded audio file (multipart form field ``audio``).
        lang: Language hint from the form; accepted for API compatibility but
            not used by this backend.
        background_tasks: Kept for interface compatibility; cleanup no longer
            relies on it.

    Returns:
        ``{"generated_text": <recognised text>}``.

    Raises:
        HTTPException: Status 500 with the traceback when processing fails.
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = test_sherpa(file_path)
        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
    finally:
        # Remove the temporary upload on success AND on failure. Background
        # tasks are attached to the normal response only, so they are dropped
        # when an HTTPException is raised and the file would stay on disk.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass
# if __name__ == "__main__":
# uvicorn.run("fastapi_sherpa:app", host="0.0.0.0", port=1111, workers=1)

View File

@@ -0,0 +1,33 @@
import argparse
import uvicorn
from fastapi_sherpa import app
if __name__ == "__main__":
    # Command-line entry point: parse the options, publish them to the
    # application through app.state.config and start the uvicorn server.
    parser = argparse.ArgumentParser()
    # The default only takes effect when the option is not marked as required.
    parser.add_argument("--model_dir", type=str, default="/model", help="model directory")
    parser.add_argument("--model_type", type=str, default=None, help="model type, e.g. sensevoice")
    # GPU is the default; --use_gpu is kept so existing command lines still work.
    parser.add_argument("--use_gpu", action="store_true", default=True)
    # Without this switch there would be no way to request CPU inference.
    parser.add_argument("--no_use_gpu", dest="use_gpu", action="store_false",
                        help="run inference on CPU instead of GPU")
    parser.add_argument("--warmup", action="store_true", help="whether do warmup when first initializing model")
    parser.add_argument("--model_name", type=str, default=None, help="model's full name(optional) to determine model type")
    parser.add_argument("--num_threads", type=int, default=2, help="number of threads with model inference")
    parser.add_argument("--offline_model", action="store_true", help="indicating a non-streaming model when this flag is set")
    parser.add_argument("--port", type=int, default=8000, help="service port")
    args = parser.parse_args()

    # Hand the parsed options over to the application.
    app.state.config = {
        "model_dir": args.model_dir,
        "model_type": args.model_type,
        "model_name": args.model_name,
        "num_threads": args.num_threads,
        "offline_model": args.offline_model,
        "use_gpu": args.use_gpu,
        "warmup": args.warmup,
    }

    uvicorn.run("fastapi_sherpa:app",
                host="0.0.0.0",
                port=args.port,
                workers=1
                )

View File

@@ -0,0 +1,14 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PYYaml
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
fastapi
uvicorn
python-multipart

BIN
sherpa-onnx/warmup.wav Normal file

Binary file not shown.

View File

@@ -0,0 +1,23 @@
FROM corex:4.3.8

WORKDIR /root

RUN set -eux; \
    # 1) Replace the aliyun mirror with the official Ubuntu archive (avoids HTTP 403)
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list.d/*.list 2>/dev/null || true; \
    \
    # 2) Refresh the package index and install the system packages
    apt-get update; \
    apt-get install -y --no-install-recommends \
        ca-certificates \
        libasound2-dev \
        net-tools \
        patchelf \
        vim; \
    rm -rf /var/lib/apt/lists/*

# Install the Python dependencies before copying the sources, so that a
# code-only change reuses the cached dependency layers.
COPY requirements.txt /root/
RUN pip install --no-cache-dir -r requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple
RUN pip install --no-cache-dir transformers==4.51.3 -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple

# Application sources (plain local files, so COPY rather than ADD)
COPY . /root/

ENTRYPOINT ["python3"]
CMD ["./main_transformers.py"]

28
transformers/README.md Normal file
View File

@@ -0,0 +1,28 @@
# 天数智芯 天垓150 ASR(Transformers 架构)
## 镜像构造
```shell
docker build -f ./Dockerfile.transformers-bi150 -t <your_image> .
```
其中,基础镜像 corex:4.3.8 可通过联系天数智芯智铠100厂商技术支持获取。
## 使用说明
### 使用 FastAPI 启动ASR服务
例如:
```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
-v /mnt/contest_ceph/leaderboard/modelHubXC/openai-mirror/whisper-small:/model \
--network=host <your_image> \
main_transformers.py --model_dir /model --use_gpu --port 1111
```
具体参数代码设定可参考代码文件
### 测试ASR服务
项目根路径下的 `sample_data` 目录中附带了中文测试音频及其配套内容。
```shell
curl -X POST http://localhost:1111/transduce \
-F "audio=@../sample_data/lei-jun-test.wav" \
-F "lang=zh"
```

View File

@@ -0,0 +1,232 @@
import os
import time
import uuid
import json
import inspect
import traceback
import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uvicorn
from transformers import pipeline as hf_pipeline
# --- Module-level service state -------------------------------------------
# FastAPI application object; its routes and startup hook are registered below.
app = FastAPI()

# Lifecycle flag reported by /health: "Running" until the model is loaded.
status = "Running"
# Populated by load_model() at application startup.
asr_pipeline = None
# The only branch that has to be distinguished: Whisper (seq2seq) versus all
# the remaining CTC-style models.
is_whisper = False
device = ""

# Uploaded audio is written into this directory before being decoded.
os.makedirs("./input", exist_ok=True)

# Optional accelerator hint; the matching torch backend plugin is imported
# for its side effects.
CUSTOM_DEVICE = os.environ.get("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu
class _SamplingRateCompatProxy:
"""为非标准 FeatureExtractor 提供兼容性包装。
transformers pipeline 的 preprocess 固定会向 feature_extractor 传 sampling_rate、
return_tensors 等标准 kwargs但部分模型如 GraniteSpeech的 FeatureExtractor
没有实现这些参数。此代理在初始化时检查签名,调用时只转发 FeatureExtractor 实际接受的参数。
调用前须确保音频已按模型期望采样率重采样完毕run_asr 中已完成)。
"""
def __init__(self, fe):
object.__setattr__(self, "_fe", fe)
# 初始化时检查一次签名,确定接受哪些参数
try:
sig = inspect.signature(fe.__call__)
has_var_kw = any(
p.kind == inspect.Parameter.VAR_KEYWORD
for p in sig.parameters.values()
)
accepted = None if has_var_kw else set(sig.parameters.keys()) - {"self"}
except Exception:
accepted = None # 无法检测时不过滤
object.__setattr__(self, "_accepted", accepted)
def __call__(self, *args, **kwargs):
accepted = object.__getattribute__(self, "_accepted")
if accepted is not None:
kwargs = {k: v for k, v in kwargs.items() if k in accepted}
return object.__getattribute__(self, "_fe")(*args, **kwargs)
def __getattr__(self, name):
return getattr(object.__getattribute__(self, "_fe"), name)
def __setattr__(self, name, value):
setattr(object.__getattribute__(self, "_fe"), name, value)
def _check_is_whisper(model_dir: str, model_type_override: str = None) -> bool:
"""判断是否为 Whisper 架构。
优先使用用户显式传入的 model_type_override
否则读 config.json 中的 model_type 字段(所有 whisper fine-tuned 模型均有此字段)。
"""
if model_type_override:
return model_type_override.lower() == "whisper"
config_path = os.path.join(model_dir, "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
cfg = json.load(f)
return cfg.get("model_type", "").lower() == "whisper"
# config.json 不存在时,从目录名做最后兜底
return "whisper" in os.path.basename(model_dir).lower()
@app.on_event("startup")
def load_model():
    """Startup hook: build the global ASR pipeline from ``app.state.config``.

    Sets the module globals ``device``, ``is_whisper``, ``asr_pipeline`` and,
    once loading (and the optional warmup) is done, ``status = "Success"``.
    """
    global status, asr_pipeline, is_whisper, device
    cfg = app.state.config
    want_gpu = cfg.get("use_gpu", True)
    model_dir = cfg.get("model_dir", "/model")
    forced_type = cfg.get("model_type", None)  # optional override of the auto-detection
    do_warmup = cfg.get("warmup", False)
    half_precision = cfg.get("fp16", False)  # fp32 unless explicitly requested

    # The device is passed to the pipeline as a plain string.
    if not want_gpu:
        device = "cpu"
    elif CUSTOM_DEVICE.startswith("mlu"):
        device = "mlu:0"
    elif CUSTOM_DEVICE.startswith("ascend"):
        device = "npu:0"
    else:
        device = "cuda:0"

    is_whisper = _check_is_whisper(model_dir, forced_type)

    # fp32 is the default (best cross-platform compatibility, no precision
    # difference when comparing hardware); fp16 must be enabled with --fp16.
    torch_dtype = torch.float32 if not half_precision else torch.float16

    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" is_whisper =", is_whisper, flush=True)
    print(" device =", device, flush=True)
    print(" torch_dtype =", torch_dtype, flush=True)
    print(" chunk_length_s =", app.state.config.get("chunk_length_s", 30), flush=True)
    print(" warmup =", do_warmup, flush=True)

    # chunk_length_s is deliberately NOT passed to the pipeline: run_asr()
    # slices the audio itself and calls the pipeline once per slice. Per the
    # original author, the pipeline's own chunking always passes
    # sampling_rate, which some feature extractors do not accept.
    asr_pipeline = hf_pipeline(
        task="automatic-speech-recognition",
        model=model_dir,
        device=device,
        torch_dtype=torch_dtype,
    )

    # The pipeline's preprocessing always passes sampling_rate; wrap feature
    # extractors whose __call__ signature does not take it.
    try:
        fe_params = inspect.signature(asr_pipeline.feature_extractor.__call__).parameters
        accepts_sr = "sampling_rate" in fe_params
        if not accepts_sr:
            accepts_sr = any(
                param.kind == inspect.Parameter.VAR_KEYWORD
                for param in fe_params.values()
            )
    except Exception:
        accepts_sr = True  # cannot inspect: conservatively assume it is accepted
    if not accepts_sr:
        asr_pipeline.feature_extractor = _SamplingRateCompatProxy(asr_pipeline.feature_extractor)
        print(" Note: FeatureExtractor does not accept sampling_rate, applied compat proxy", flush=True)

    if do_warmup:
        print("Start warmup...", flush=True)
        # Use the sampling rate the feature extractor declares; fall back to
        # 16000 when it has no such attribute.
        target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
        silence = np.zeros(target_sr, dtype=np.float32)  # one second of silence
        asr_pipeline(silence, **_build_infer_kwargs("zh"))
        print("warmup complete.", flush=True)
    status = "Success"
def _build_infer_kwargs(lang: str) -> dict:
"""Whisper 推理时需要额外传语言参数CTC 类无需额外参数。
不再传 return_timestamps因为我们自行分片后逐段调用 pipeline无需 pipeline 内部拼接。
"""
if is_whisper:
return {"generate_kwargs": {"language": lang, "task": "transcribe"}}
return {}
def run_asr(audio_file: str, lang: str) -> str:
    """Transcribe one audio file with the global pipeline and log timing.

    The audio is down-mixed to mono, resampled to the rate the feature
    extractor declares, cut into slices of ``chunk_length_s`` seconds and sent
    to the pipeline slice by slice; the partial texts are concatenated.

    Args:
        audio_file: Path of an audio file readable by ``torchaudio``.
        lang: Language passed to Whisper models; ignored by other models.

    Returns:
        The recognised text.
    """
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate

    # Down-mix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample up front and hand the pipeline a bare numpy array (not a
    # dict), so the pipeline never has to be told the sampling rate. Fall
    # back to 16000 when the extractor has no sampling_rate attribute.
    target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)
    audio_array = waveform.squeeze(0).numpy().astype(np.float32)

    chunk_length_s = app.state.config.get("chunk_length_s", 30)
    chunk_samples = int(chunk_length_s * target_sr)
    if chunk_samples <= 0:
        # A non-positive chunk length disables slicing instead of breaking
        # range() below.
        chunk_samples = max(len(audio_array), 1)
    infer_kwargs = _build_infer_kwargs(lang)

    ts1 = time.time()
    texts = []
    for i in range(0, len(audio_array), chunk_samples):
        chunk = audio_array[i : i + chunk_samples]
        result = asr_pipeline(chunk, **infer_kwargs)
        texts.append(result["text"])
    ts2 = time.time()

    # wav2vec2-style models use U+2581 as the word separator; turn it into a
    # space. (Replacing the EMPTY string, as the previous code also did,
    # inserts a space between every single character of the transcript.)
    generated_text = "".join(texts).replace("\u2581", " ").strip()

    processing_time = ts2 - ts1
    # An empty file has zero duration; avoid a ZeroDivisionError in the log.
    rtf = processing_time / duration if duration > 0 else float("inf")
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Report the service state derived from the module-level ``status`` flag.

    "Running" maps to "loading model", "Success" to "ok" and every other
    value to "failed".
    """
    labels = {"Running": "loading model", "Success": "ok"}
    return {"status": labels.get(status, "failed")}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None,
):
    """Transcribe an uploaded audio file.

    Args:
        audio: Uploaded audio file (multipart form field ``audio``).
        lang: Language hint from the form, forwarded to ``run_asr``.
        background_tasks: Kept for interface compatibility; cleanup no longer
            relies on it.

    Returns:
        ``{"generated_text": <recognised text>}``.

    Raises:
        HTTPException: Status 500 with the traceback when processing fails.
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = run_asr(file_path, lang)
        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
    finally:
        # Remove the temporary upload on success AND on failure. Background
        # tasks are attached to the normal response only, so they are dropped
        # when an HTTPException is raised and the file would stay on disk.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass
# if __name__ == "__main__":
# uvicorn.run("fastapi_transformers:app", host="0.0.0.0", port=8000, workers=1)

View File

@@ -0,0 +1,38 @@
import argparse
import uvicorn
from fastapi_transformers import app
if __name__ == "__main__":
    # Command-line entry point: parse the options, publish them to the
    # application through app.state.config and start the uvicorn server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/model",
                        help="模型目录(挂载到容器内的路径)")
    parser.add_argument("--model_type", type=str, default=None,
                        help="可选,仅在自动推断失败时手动指定: whisper 或不填CTC 类均不需要填)")
    # GPU is the default; --use_gpu is kept so existing command lines still work.
    parser.add_argument("--use_gpu", action="store_true", default=True,
                        help="是否使用 GPUCUDA")
    # Without this switch there would be no way to request CPU inference.
    parser.add_argument("--no_use_gpu", dest="use_gpu", action="store_false",
                        help="run inference on CPU instead of GPU")
    parser.add_argument("--warmup", action="store_true",
                        help="启动时用静音片段执行一次 warmup 推理")
    parser.add_argument("--chunk_length_s", type=int, default=30,
                        help="长音频切片长度(秒),逐段推理,默认 30")
    parser.add_argument("--fp16", action="store_true", default=False,
                        help="使用 float16 推理(默认 float32。仅在确认硬件支持时开启"
                             "注意 fp16/fp32 之间存在精度差异,跨卡对比时建议保持默认 fp32")
    parser.add_argument("--port", type=int, default=8000,
                        help="FastAPI 服务端口,默认 8000")
    args = parser.parse_args()

    # The slice length is used as a range() step during inference; reject
    # values that cannot work before the server starts.
    if args.chunk_length_s <= 0:
        parser.error("--chunk_length_s must be a positive integer")

    app.state.config = {
        "model_dir": args.model_dir,
        "model_type": args.model_type,
        "use_gpu": args.use_gpu,
        "warmup": args.warmup,
        "chunk_length_s": args.chunk_length_s,
        "fp16": args.fp16,
    }

    uvicorn.run("fastapi_transformers:app",
                host="0.0.0.0",
                port=args.port,
                workers=1
                )

View File

@@ -0,0 +1,14 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PYYaml
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
fastapi
uvicorn
python-multipart

BIN
transformers/warmup.wav Normal file

Binary file not shown.