initial commit
23
transformers/Dockerfile.transformers-bi150
Normal file
@@ -0,0 +1,23 @@
FROM corex:4.3.8

WORKDIR /root

RUN set -eux; \
    # 1) Replace the aliyun mirrors with the official Ubuntu ones (avoids 403 errors)
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list.d/*.list 2>/dev/null || true; \
    \
    # 2) Update and install packages
    apt-get update; \
    apt-get install -y --no-install-recommends vim net-tools ca-certificates libasound2-dev patchelf; \
    rm -rf /var/lib/apt/lists/*

ADD . /root/

COPY requirements.txt /root

RUN pip install -r requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple

RUN pip install transformers==4.51.3 -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple

ENTRYPOINT ["python3"]
CMD ["./main_transformers.py"]
28
transformers/README.md
Normal file
@@ -0,0 +1,28 @@
# Iluvatar CoreX Tiangai 150 ASR (Transformers architecture)

## Building the image
```shell
docker build -f ./Dockerfile.transformers-bi150 -t <your_image> .
```
The base image corex:4.3.8 can be obtained by contacting Iluvatar CoreX vendor technical support.
## Usage

### Start the ASR service with FastAPI

For example:
```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
    -v /mnt/contest_ceph/leaderboard/modelHubXC/openai-mirror/whisper-small:/model \
    --network=host <your_image> \
    main_transformers.py --model_dir /model --use_gpu --port 1111
```
See the code files for the full set of supported parameters.
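Before sending requests, you can poll the service's `/health` endpoint (defined in `fastapi_transformers.py`) to check whether the model has finished loading. A sketch, assuming the service was started on port 1111 as above:
```shell
curl http://localhost:1111/health
# {"status": "loading model"}   -> model is still loading
# {"status": "ok"}              -> ready to accept /transduce requests
```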
### Testing the ASR service

The `sample_data` directory at the project root contains a Chinese test audio file and accompanying materials.

```shell
curl -X POST http://localhost:1111/transduce \
    -F "audio=@../sample_data/lei-jun-test.wav" \
    -F "lang=zh"
```
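On success, `/transduce` returns the recognized text as JSON, and the service logs the audio duration, elapsed time, and RTF to the container's stdout. Expected response shape (transcript text illustrative):
```shell
# sample /transduce response body
{"generated_text": "<recognized transcript>"}
```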
232
transformers/fastapi_transformers.py
Normal file
@@ -0,0 +1,232 @@
import os
import time
import uuid
import json
import inspect
import traceback
import numpy as np
import torch
import torchaudio

from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uvicorn
from transformers import pipeline as hf_pipeline

os.makedirs("./input", exist_ok=True)
status = "Running"
asr_pipeline = None
is_whisper = False  # the only branch we need to distinguish: Whisper (seq2seq) vs. all other CTC-style models
device = ""
app = FastAPI()

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu


class _SamplingRateCompatProxy:
    """Compatibility wrapper for non-standard FeatureExtractors.

    The transformers pipeline's preprocess step always passes standard kwargs such as
    sampling_rate and return_tensors to the feature_extractor, but some models'
    FeatureExtractors (e.g. GraniteSpeech) do not implement these parameters.
    This proxy inspects the signature once at initialization and, on each call,
    forwards only the parameters the FeatureExtractor actually accepts.
    The audio must already be resampled to the model's expected sampling rate
    before the call (run_asr takes care of this).
    """
    def __init__(self, fe):
        object.__setattr__(self, "_fe", fe)
        # Inspect the signature once to determine which parameters are accepted
        try:
            sig = inspect.signature(fe.__call__)
            has_var_kw = any(
                p.kind == inspect.Parameter.VAR_KEYWORD
                for p in sig.parameters.values()
            )
            accepted = None if has_var_kw else set(sig.parameters.keys()) - {"self"}
        except Exception:
            accepted = None  # if inspection fails, do not filter
        object.__setattr__(self, "_accepted", accepted)

    def __call__(self, *args, **kwargs):
        accepted = object.__getattribute__(self, "_accepted")
        if accepted is not None:
            kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return object.__getattribute__(self, "_fe")(*args, **kwargs)

    def __getattr__(self, name):
        return getattr(object.__getattribute__(self, "_fe"), name)

    def __setattr__(self, name, value):
        setattr(object.__getattribute__(self, "_fe"), name, value)


def _check_is_whisper(model_dir: str, model_type_override: str = None) -> bool:
    """Determine whether the model is a Whisper architecture.

    Prefer an explicit model_type_override passed by the user; otherwise read the
    model_type field from config.json (all Whisper fine-tuned models have it).
    """
    if model_type_override:
        return model_type_override.lower() == "whisper"
    config_path = os.path.join(model_dir, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            cfg = json.load(f)
        return cfg.get("model_type", "").lower() == "whisper"
    # Fall back to the directory name when config.json is missing
    return "whisper" in os.path.basename(model_dir).lower()


@app.on_event("startup")
def load_model():
    global status, asr_pipeline, is_whisper, device

    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    model_type_override = config.get("model_type", None)  # optional, only overrides auto-detection
    warmup = config.get("warmup", False)
    use_fp16 = config.get("fp16", False)  # defaults to fp32; must be enabled explicitly

    # Device-string logic kept consistent with fastapi_funasr.py; the string is passed to the pipeline as-is
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    is_whisper = _check_is_whisper(model_dir, model_type_override)
    # Default to fp32: best cross-platform compatibility, and it keeps accuracy comparisons clean.
    # fp16 must be enabled explicitly (--fp16) and only after confirming the hardware supports it.
    torch_dtype = torch.float16 if use_fp16 else torch.float32

    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" is_whisper =", is_whisper, flush=True)
    print(" device =", device, flush=True)
    print(" torch_dtype =", torch_dtype, flush=True)
    print(" chunk_length_s =", app.state.config.get("chunk_length_s", 30), flush=True)
    print(" warmup =", warmup, flush=True)

    # The transformers pipeline accepts device strings directly ("cpu"/"cuda:0"/"mlu:0"/"npu:0")
    # and reads config.json to instantiate the correct model class, so no architecture is specified here.
    # Note: chunk_length_s is deliberately not passed at pipeline construction; run_asr chunks the
    # audio itself and calls the pipeline per segment. Reason: some models' FeatureExtractors
    # (e.g. GraniteSpeech) do not accept a sampling_rate parameter, while the pipeline's internal
    # chunk_iter always passes it, which raises an error.
    asr_pipeline = hf_pipeline(
        task="automatic-speech-recognition",
        model=model_dir,
        device=device,
        torch_dtype=torch_dtype,
    )

    # Check whether the feature extractor accepts a sampling_rate parameter.
    # The pipeline's preprocess always passes it (hard-coded behavior); models that do not
    # accept it need the compat proxy wrapper.
    try:
        sig = inspect.signature(asr_pipeline.feature_extractor.__call__)
        accepts_sr = "sampling_rate" in sig.parameters or any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
    except Exception:
        accepts_sr = True  # if inspection fails, conservatively assume it is accepted
    if not accepts_sr:
        asr_pipeline.feature_extractor = _SamplingRateCompatProxy(asr_pipeline.feature_extractor)
        print(" Note: FeatureExtractor does not accept sampling_rate, applied compat proxy", flush=True)

    if warmup:
        print("Start warmup...", flush=True)
        # Get the model's expected sampling rate; almost every feature extractor exposes this attribute.
        # The rare non-standard model may not, so fall back to 16000 (the de facto standard for ASR).
        target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
        dummy = np.zeros(target_sr, dtype=np.float32)  # 1 second of silence
        asr_pipeline(dummy, **_build_infer_kwargs("zh"))
        print("warmup complete.", flush=True)

    status = "Success"


def _build_infer_kwargs(lang: str) -> dict:
    """Whisper inference needs extra language kwargs; CTC-style models need none.

    return_timestamps is no longer passed: we chunk the audio ourselves and call the
    pipeline per segment, so the pipeline's internal stitching is not needed.
    """
    if is_whisper:
        return {"generate_kwargs": {"language": lang, "task": "transcribe"}}
    return {}


def run_asr(audio_file: str, lang: str) -> str:
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate

    # Downmix multi-channel audio to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample up front to the model's expected sampling rate.
    # Passing a numpy array (not a dict) to the pipeline skips its internal sampling_rate
    # plumbing, working around FeatureExtractors (e.g. GraniteSpeech) that do not accept
    # a sampling_rate parameter.
    # Almost every feature extractor exposes sampling_rate; the rare non-standard model
    # may not, so fall back to 16000 (the de facto standard for ASR).
    target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)

    audio_array = waveform.squeeze(0).numpy().astype(np.float32)

    chunk_length_s = app.state.config.get("chunk_length_s", 30)
    chunk_samples = chunk_length_s * target_sr
    infer_kwargs = _build_infer_kwargs(lang)

    ts1 = time.time()
    texts = []
    for i in range(0, len(audio_array), chunk_samples):
        chunk = audio_array[i : i + chunk_samples]
        result = asr_pipeline(chunk, **infer_kwargs)
        texts.append(result["text"])
    ts2 = time.time()

    generated_text = "".join(texts)
    # wav2vec2-family models use U+2581 (▁) as the word delimiter; replace it with a space
    generated_text = generated_text.replace("\u2581", " ").strip()

    processing_time = ts2 - ts1
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)

    return generated_text


@app.get("/health")
def health():
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "ok" if status == "Success" else "failed"}


@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None,
):
    try:
        file_path = f"./input/{uuid.uuid4()}.wav"
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        background_tasks.add_task(os.remove, file_path)
        generated_text = run_asr(file_path, lang)
        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")


# if __name__ == "__main__":
#     uvicorn.run("fastapi_transformers:app", host="0.0.0.0", port=8000, workers=1)
38
transformers/main_transformers.py
Normal file
@@ -0,0 +1,38 @@
import argparse
import uvicorn
from fastapi_transformers import app

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/model",
                        help="model directory (path as mounted inside the container)")
    parser.add_argument("--model_type", type=str, default=None,
                        help="optional; set manually only when auto-detection fails: "
                             "'whisper', or leave unset (CTC-style models never need it)")
    parser.add_argument("--use_gpu", action="store_true",
                        help="use the GPU (CUDA); defaults to CPU when omitted")
    parser.add_argument("--warmup", action="store_true",
                        help="run one warmup inference on a silent clip at startup")
    parser.add_argument("--chunk_length_s", type=int, default=30,
                        help="chunk length in seconds for long audio, inferred per segment; default 30")
    parser.add_argument("--fp16", action="store_true", default=False,
                        help="run inference in float16 (default float32). Enable only after confirming "
                             "hardware support; fp16 and fp32 results differ in precision, so keep the "
                             "default fp32 when comparing across devices")
    parser.add_argument("--port", type=int, default=8000,
                        help="FastAPI service port, default 8000")

    args = parser.parse_args()

    app.state.config = {
        "model_dir": args.model_dir,
        "model_type": args.model_type,
        "use_gpu": args.use_gpu,
        "warmup": args.warmup,
        "chunk_length_s": args.chunk_length_s,
        "fp16": args.fp16,
    }

    uvicorn.run("fastapi_transformers:app",
                host="0.0.0.0",
                port=args.port,
                workers=1
                )
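For local debugging outside the container, the entrypoint above can also be invoked directly (a sketch; the model path and port are assumptions carried over from the README example):
```shell
python3 main_transformers.py --model_dir /model --use_gpu --warmup --port 1111
```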
14
transformers/requirements.txt
Normal file
@@ -0,0 +1,14 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PyYAML
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
fastapi
uvicorn
python-multipart
BIN
transformers/warmup.wav
Normal file
Binary file not shown.