initial commit

This commit is contained in:
2026-04-08 06:41:00 +00:00
commit 1385a6f46b
23 changed files with 2831 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
FROM corex:4.3.8
WORKDIR /root
RUN set -eux; \
    # 1) Replace the aliyun mirror with the official archive (avoids 403 errors)
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list.d/*.list 2>/dev/null || true; \
    \
    # 2) Update and install in the same layer, then drop the apt lists to keep the layer small
    apt-get update; \
    apt-get install -y --no-install-recommends vim net-tools ca-certificates libasound2-dev patchelf; \
    rm -rf /var/lib/apt/lists/*
# Copy the dependency manifest first so the pip layer stays cached until
# requirements.txt itself changes (previously `ADD . /root/` came first, so any
# source edit reinstalled every dependency).
COPY requirements.txt /root
# --no-cache-dir: the pip download cache would otherwise be baked into the layer.
# The pinned transformers install runs last in the same layer so it wins over
# whatever version requirements.txt may pull in.
RUN pip install --no-cache-dir -r requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple \
 && pip install --no-cache-dir transformers==4.51.3 -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple
# Application source last — code changes no longer invalidate the dependency layers.
# COPY instead of ADD: plain local files need none of ADD's extra behaviors.
COPY . /root/
ENTRYPOINT ["python3"]
CMD ["./main_transformers.py"]

28
transformers/README.md Normal file
View File

@@ -0,0 +1,28 @@
# 天数智芯 天垓150 ASRTransformers架构
## 镜像构造
```shell
docker build -f ./Dockerfile.transformers-bi150 -t <your_image> .
```
其中,基础镜像 corex:4.3.8 可通过联系天数智芯天垓150厂商技术支持获取
## 使用说明
### 使用 FastAPI 启动ASR服务
例如:
```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
-v /mnt/contest_ceph/leaderboard/modelHubXC/openai-mirror/whisper-small:/model \
--network=host <your_image> \
main_transformers.py --model_dir /model --use_gpu --port 1111
```
具体参数代码设定可参考代码文件
### 测试ASR服务
项目根路径的`sample_data`目录下附带了中文的测试音频及对应的文本内容
```shell
curl -X POST http://localhost:1111/transduce \
-F "audio=@../sample_data/lei-jun-test.wav" \
-F "lang=zh"
```

View File

@@ -0,0 +1,232 @@
import os
import time
import uuid
import json
import inspect
import traceback
import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uvicorn
from transformers import pipeline as hf_pipeline
# Staging directory for uploaded audio files before inference.
os.makedirs("./input", exist_ok=True)
# Service lifecycle state: "Running" while the model loads; load_model() sets it
# to "Success" once the pipeline is ready (health() reports anything else as failed).
status = "Running"
asr_pipeline = None
is_whisper = False  # the only branch we need to distinguish: Whisper (seq2seq) vs. all other CTC-style models
device = ""
app = FastAPI()
# CUSTOM_DEVICE selects a vendor-specific torch backend via its prefix
# ("mlu*" -> torch_mlu, "ascend*" -> torch_npu, "pt*" -> torch_dipu);
# load_model() reads the same variable to build the device string.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu
class _SamplingRateCompatProxy:
    """Compatibility wrapper for non-standard FeatureExtractors.

    transformers' pipeline preprocess always passes standard kwargs such as
    sampling_rate and return_tensors to the feature_extractor, but some models'
    FeatureExtractors (e.g. GraniteSpeech) do not implement those parameters.
    This proxy inspects the wrapped extractor's signature once at construction
    and, on each call, forwards only the kwargs it actually accepts.
    Callers must make sure the audio is already resampled to the model's
    expected rate before invoking the proxy (run_asr takes care of that).
    """
    def __init__(self, fe):
        # Store via object.__setattr__ to bypass our own __setattr__ override below.
        object.__setattr__(self, "_fe", fe)
        # Inspect the signature once up front to determine the accepted kwargs.
        try:
            sig = inspect.signature(fe.__call__)
            has_var_kw = any(
                p.kind == inspect.Parameter.VAR_KEYWORD
                for p in sig.parameters.values()
            )
            # None means "accepts anything" (either **kwargs present or undetectable).
            accepted = None if has_var_kw else set(sig.parameters.keys()) - {"self"}
        except Exception:
            accepted = None  # when undetectable, do not filter
        object.__setattr__(self, "_accepted", accepted)
    def __call__(self, *args, **kwargs):
        # Drop any kwargs the wrapped extractor does not accept; None disables filtering.
        accepted = object.__getattribute__(self, "_accepted")
        if accepted is not None:
            kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return object.__getattribute__(self, "_fe")(*args, **kwargs)
    def __getattr__(self, name):
        # Delegate attribute reads (e.g. sampling_rate) to the wrapped extractor.
        return getattr(object.__getattribute__(self, "_fe"), name)
    def __setattr__(self, name, value):
        # Attribute writes also go to the wrapped extractor, keeping it the single source of truth.
        setattr(object.__getattribute__(self, "_fe"), name, value)
def _check_is_whisper(model_dir: str, model_type_override: str = None) -> bool:
"""判断是否为 Whisper 架构。
优先使用用户显式传入的 model_type_override
否则读 config.json 中的 model_type 字段(所有 whisper fine-tuned 模型均有此字段)。
"""
if model_type_override:
return model_type_override.lower() == "whisper"
config_path = os.path.join(model_dir, "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
cfg = json.load(f)
return cfg.get("model_type", "").lower() == "whisper"
# config.json 不存在时,从目录名做最后兜底
return "whisper" in os.path.basename(model_dir).lower()
@app.on_event("startup")
def load_model():
    """Build the ASR pipeline once at startup from ``app.state.config``.

    Populates the module globals ``asr_pipeline``, ``is_whisper`` and
    ``device``; sets ``status`` to "Success" when loading completes, which
    flips the /health endpoint from "loading model" to "ok".
    """
    global status, asr_pipeline, is_whisper, device
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    model_type_override = config.get("model_type", None)  # optional, only to override auto-detection
    warmup = config.get("warmup", False)
    use_fp16 = config.get("fp16", False)  # default fp32; the user must opt in explicitly
    # Device-string logic kept consistent with fastapi_funasr.py; the string is
    # handed directly to the transformers pipeline.
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"
    is_whisper = _check_is_whisper(model_dir, model_type_override)
    # Default fp32: best cross-platform compatibility, no impact on accuracy comparisons.
    # fp16 must be requested explicitly (--fp16) and only on hardware known to support it.
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" is_whisper =", is_whisper, flush=True)
    print(" device =", device, flush=True)
    print(" torch_dtype =", torch_dtype, flush=True)
    print(" chunk_length_s =", app.state.config.get("chunk_length_s", 30), flush=True)
    print(" warmup =", warmup, flush=True)
    # The transformers pipeline accepts a device string directly ("cpu"/"cuda:0"/"mlu:0"/"npu:0")
    # and reads config.json itself to pick the right model class — no architecture name needed.
    # Note: chunk_length_s is deliberately NOT passed at pipeline construction; run_asr chunks
    # the audio itself and calls the pipeline per segment. Reason: some models'
    # FeatureExtractor (e.g. GraniteSpeech) rejects the sampling_rate kwarg that the
    # pipeline's internal chunk_iter always passes, which would raise.
    asr_pipeline = hf_pipeline(
        task="automatic-speech-recognition",
        model=model_dir,
        device=device,
        torch_dtype=torch_dtype,
    )
    # Check whether the feature extractor accepts sampling_rate; pipeline.preprocess
    # always passes it (hard-coded), so extractors that reject it need the compat proxy.
    try:
        sig = inspect.signature(asr_pipeline.feature_extractor.__call__)
        accepts_sr = "sampling_rate" in sig.parameters or any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
    except Exception:
        accepts_sr = True  # when undetectable, conservatively assume it is accepted
    if not accepts_sr:
        asr_pipeline.feature_extractor = _SamplingRateCompatProxy(asr_pipeline.feature_extractor)
        print(" Note: FeatureExtractor does not accept sampling_rate, applied compat proxy", flush=True)
    if warmup:
        print("Start warmup...", flush=True)
        # Use the sampling rate the model expects; nearly all feature extractors expose it.
        # Fall back to 16000 (the most common ASR rate) for rare non-standard models.
        target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
        dummy = np.zeros(target_sr, dtype=np.float32)  # 1 second of silence
        asr_pipeline(dummy, **_build_infer_kwargs("zh"))
        print("warmup complete.", flush=True)
    status = "Success"
def _build_infer_kwargs(lang: str) -> dict:
    """Extra kwargs for a pipeline inference call.

    Whisper needs the target language and task passed via ``generate_kwargs``;
    CTC-style models take no extra arguments. ``return_timestamps`` is
    deliberately omitted: chunking happens in the caller, so the pipeline never
    has to stitch segments back together.
    """
    if not is_whisper:
        return {}
    return {"generate_kwargs": {"language": lang, "task": "transcribe"}}
def run_asr(audio_file: str, lang: str) -> str:
    """Transcribe ``audio_file`` and return the recognized text.

    Loads the audio, downmixes to mono, resamples to the model's expected rate,
    then feeds fixed-length chunks through the global ``asr_pipeline`` and
    concatenates the per-chunk texts. Logs the transcription, timings, and the
    real-time factor (RTF = processing time / audio duration).

    :param audio_file: path to a local audio file readable by torchaudio.
    :param lang: language hint, forwarded to Whisper models only.
    :return: the concatenated transcription.
    """
    waveform, sample_rate = torchaudio.load(audio_file)
    # Duration is computed from the original rate, before any resampling.
    duration = waveform.shape[1] / sample_rate
    # Downmix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample up front and hand the pipeline a raw numpy array (not a dict):
    # this skips the pipeline's internal sampling_rate plumbing, which breaks on
    # models (e.g. GraniteSpeech) whose FeatureExtractor rejects that kwarg.
    # Most feature extractors expose `sampling_rate`; fall back to 16000, the
    # de-facto ASR standard, for the rare non-standard ones.
    target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)
    audio_array = waveform.squeeze(0).numpy().astype(np.float32)
    chunk_length_s = app.state.config.get("chunk_length_s", 30)
    chunk_samples = chunk_length_s * target_sr
    infer_kwargs = _build_infer_kwargs(lang)
    ts1 = time.time()
    texts = []
    for i in range(0, len(audio_array), chunk_samples):
        chunk = audio_array[i : i + chunk_samples]
        result = asr_pipeline(chunk, **infer_kwargs)
        texts.append(result["text"])
    ts2 = time.time()
    generated_text = "".join(texts)
    # wav2vec2-family models emit U+2581 (▁) as the word separator; map it to a
    # space. BUG FIX: the previous extra `replace("", " ")` call inserted a
    # space between every character — str.replace with an empty pattern matches
    # at every position. The chr(9601) replacement alone is the intended mapping.
    generated_text = generated_text.replace(chr(9601), " ").strip()
    processing_time = ts2 - ts1
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Readiness probe: reports the model-loading lifecycle state."""
    # "Running" means load_model() has not finished yet.
    if status == "Running":
        return {"status": "loading model"}
    if status == "Success":
        return {"status": "ok"}
    return {"status": "failed"}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None,
):
    """Accept an uploaded audio file, transcribe it, and return the text.

    The upload is written to a unique temp file under ./input, transcribed via
    run_asr, and the temp file is removed afterwards — via a background task on
    success (FastAPI only runs those after a successful response), or inline in
    the error path, which the original code leaked.

    :param audio: the uploaded audio file (multipart form field "audio").
    :param lang: language hint forwarded to Whisper models (form field "lang").
    :param background_tasks: injected by FastAPI; guarded against None so the
        handler also works when called directly.
    :raises HTTPException: 500 with the traceback when processing fails.
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = run_asr(file_path, lang)
        # Schedule cleanup only after a successful transcription; the error
        # path below removes the file itself.
        if background_tasks is not None:
            background_tasks.add_task(os.remove, file_path)
        else:
            os.remove(file_path)
        return {"generated_text": generated_text}
    except Exception:
        # Best-effort cleanup: background tasks never run when an exception is
        # raised, so without this the temp file would leak on every failure.
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError:
                pass
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
# if __name__ == "__main__":
# uvicorn.run("fastapi_transformers:app", host="0.0.0.0", port=8000, workers=1)

View File

@@ -0,0 +1,38 @@
import argparse
import uvicorn
from fastapi_transformers import app
if __name__ == "__main__":
    # CLI entry point: parse options, stash them on the FastAPI app, start uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/model",
                        help="模型目录(挂载到容器内的路径)")
    parser.add_argument("--model_type", type=str, default=None,
                        help="可选,仅在自动推断失败时手动指定: whisper 或不填CTC 类均不需要填)")
    # BUG FIX: the original used action="store_true" with default=True, making the
    # flag a no-op — GPU mode could never be disabled. BooleanOptionalAction
    # (Python 3.9+, satisfied by transformers 4.51's requirement) keeps --use_gpu
    # working unchanged and additionally accepts --no-use_gpu for CPU runs.
    parser.add_argument("--use_gpu", action=argparse.BooleanOptionalAction, default=True,
                        help="是否使用 GPUCUDA")
    parser.add_argument("--warmup", action="store_true",
                        help="启动时用静音片段执行一次 warmup 推理")
    parser.add_argument("--chunk_length_s", type=int, default=30,
                        help="长音频切片长度(秒),逐段推理,默认 30")
    parser.add_argument("--fp16", action="store_true", default=False,
                        help="使用 float16 推理(默认 float32。仅在确认硬件支持时开启"
                             "注意 fp16/fp32 之间存在精度差异,跨卡对比时建议保持默认 fp32")
    parser.add_argument("--port", type=int, default=8000,
                        help="FastAPI 服务端口,默认 8000")
    args = parser.parse_args()
    # load_model() (the startup hook in fastapi_transformers) reads this dict.
    # Since the module was already imported above, uvicorn's import string
    # resolves to this same `app` object, so the config survives.
    app.state.config = {
        "model_dir": args.model_dir,
        "model_type": args.model_type,
        "use_gpu": args.use_gpu,
        "warmup": args.warmup,
        "chunk_length_s": args.chunk_length_s,
        "fp16": args.fp16,
    }
    # workers=1: the model lives in this process; more workers would each load a copy.
    uvicorn.run("fastapi_transformers:app",
                host="0.0.0.0",
                port=args.port,
                workers=1
                )

View File

@@ -0,0 +1,14 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PyYAML
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
fastapi
uvicorn
python-multipart

BIN
transformers/warmup.wav Normal file

Binary file not shown.