From 189f9991524880da3510a3af7e52179eafdfd601 Mon Sep 17 00:00:00 2001
From: Lu Xinlong
Date: Thu, 16 Apr 2026 10:54:32 +0800
Subject: [PATCH] initial commit

---
 .dockerignore       |   1 +
 Dockerfile.qa_bi150 |  13 +++++
 README.md           |  39 ++++++++++++-
 fastapi_qa.py       | 137 ++++++++++++++++++++++++++++++++++++++++++++
 main_qa.py          |  25 ++++++++
 requirements.txt    |  11 ++++
 6 files changed, 225 insertions(+), 1 deletion(-)
 create mode 100644 .dockerignore
 create mode 100644 Dockerfile.qa_bi150
 create mode 100644 fastapi_qa.py
 create mode 100644 main_qa.py
 create mode 100644 requirements.txt

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..2a95e43
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1 @@
+test_scripts/
\ No newline at end of file
diff --git a/Dockerfile.qa_bi150 b/Dockerfile.qa_bi150
new file mode 100644
index 0000000..2b6ec47
--- /dev/null
+++ b/Dockerfile.qa_bi150
@@ -0,0 +1,13 @@
+FROM corex:4.3.8
+
+WORKDIR /root
+
+COPY requirements.txt /root
+RUN pip install -r requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple
+# torch ships with the corex base image (it provides the CUDA environment); only transformers is pinned here
+RUN pip install transformers==4.51.3 -i https://nexus.4pd.io/repository/pypi-all/simple
+
+COPY . /root/
+
+ENTRYPOINT ["python3"]
+CMD ["./main_qa.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 95b4e08..d410ec3 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,39 @@
-# enginex-bi_150-question-answering
+# Iluvatar CoreX Tiangai-150 (BI-150) Text Question Answering
+
+## Building the image
+```shell
+docker build -f ./Dockerfile.qa_bi150 -t qa_bi150:latest .
+```
+The base image corex:4.3.8 can be obtained by contacting Iluvatar CoreX vendor technical support.
+
+## Usage
+
+### Start the text question-answering service with FastAPI
+For example:
+```shell
+docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
+    -v /mnt/contest_ceph/leaderboard/modelHubXC/csarron/bert-base-uncased-squad-v1:/model \
+    --network=host -e CONFIG_JSON='{
+        "torch_dtype": "auto",
+        "handle_impossible_answer": false,
+        "score_threshold": 0.0,
+        "max_answer_len": 30,
+        "max_seq_len": 384,
+        "doc_stride": 128
+    }' \
+    --entrypoint=python3 \
+    qa_bi150:latest main_qa.py --model_dir /model --port 1111
+```
+See the source files for the full set of configurable parameters.
+
+### Test the service
+
+```shell
+curl -X POST http://localhost:1111/qa \
+    -H "Content-Type: application/json" \
+    -d '{
+        "context": "The capital city of China is Beijing",
+        "question": "What is the capital city of China?"
+    }'
+```

diff --git a/fastapi_qa.py b/fastapi_qa.py
new file mode 100644
index 0000000..24e4b7e
--- /dev/null
+++ b/fastapi_qa.py
@@ -0,0 +1,137 @@
+import os
+import json
+import traceback
+import torch
+
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from transformers import pipeline as hf_pipeline
+
+app = FastAPI()
+
+status = "Running"
+qa_pipeline = None
+
+CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
+if CUSTOM_DEVICE.startswith("mlu"):
+    import torch_mlu
+elif CUSTOM_DEVICE.startswith("ascend"):
+    import torch_npu
+elif CUSTOM_DEVICE.startswith("pt"):
+    import torch_dipu
+
+_DTYPE_MAP = {
+    "auto": "auto",
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+}
+
+
+def _parse_config_json() -> dict:
+    """Read optional settings from the CONFIG_JSON env var; return {} when unset."""
+    raw = os.getenv("CONFIG_JSON", "").strip()
+    if not raw:
+        return {}
+    try:
+        return json.loads(raw)
+    except json.JSONDecodeError as e:
+        print(f"[WARN] Failed to parse CONFIG_JSON, falling back to defaults: {e}", flush=True)
+        return {}
+
+
+@app.on_event("startup")
+def load_model():
+    global status, qa_pipeline
+
+    cfg = getattr(app.state, "config", {})  # set by main_qa.py; {} when run standalone
+    extra = _parse_config_json()            # from the CONFIG_JSON env var
+
+    model_dir = cfg.get("model_dir", "/model")
+    use_gpu = cfg.get("use_gpu", True)
+
+    # ---------- device ----------
+    device = "cpu"
+    if use_gpu:
+        if CUSTOM_DEVICE.startswith("mlu"):
+            device = "mlu:0"
+        elif CUSTOM_DEVICE.startswith("ascend"):
+            device = "npu:0"
+        else:
+            device = "cuda:0"
+
+    # ---------- torch_dtype ----------
+    dtype_str = extra.get("torch_dtype", "float32")
+    torch_dtype = _DTYPE_MAP.get(dtype_str, torch.float32)
+
+    # ---------- pipeline inference parameters (forwarded to every __call__) ----------
+    # handle_impossible_answer: SQuAD 2.0-style models return answer="" when they predict no-answer
+    app.state.handle_impossible_answer = extra.get("handle_impossible_answer", True)
+    app.state.score_threshold = float(extra.get("score_threshold", 0.0))
+    app.state.max_answer_len = int(extra.get("max_answer_len", 15))
+    app.state.max_seq_len = int(extra.get("max_seq_len", 384))
+    app.state.doc_stride = int(extra.get("doc_stride", 128))
+
+    print(">> Startup config:", flush=True)
+    print(f"   model_dir               = {model_dir}", flush=True)
+    print(f"   device                  = {device}", flush=True)
+    print(f"   torch_dtype             = {torch_dtype}", flush=True)
+    print(f"   handle_impossible_answer= {app.state.handle_impossible_answer}", flush=True)
+    print(f"   score_threshold         = {app.state.score_threshold}", flush=True)
+    print(f"   max_answer_len          = {app.state.max_answer_len}", flush=True)
+    print(f"   max_seq_len             = {app.state.max_seq_len}", flush=True)
+    print(f"   doc_stride              = {app.state.doc_stride}", flush=True)
+
+    qa_pipeline = hf_pipeline(
+        task="question-answering",
+        model=model_dir,
+        device=device,
+        torch_dtype=torch_dtype,
+    )
+
+    status = "Success"
+    print(">> Model loaded successfully.", flush=True)
+
+
+class QARequest(BaseModel):
+    context: str
+    question: str
+
+
+@app.get("/health")
+def health():
+    if status == "Running":
+        return {"status": "loading model"}
+    return {"status": "ok" if status == "Success" else "failed"}
+
+
+@app.post("/qa")
+def qa(req: QARequest):
+    if status != "Success":
+        raise HTTPException(status_code=503, detail="Model not ready")
+    try:
+        result = qa_pipeline(
+            question=req.question,
+            context=req.context,
+            handle_impossible_answer=app.state.handle_impossible_answer,
+            max_answer_len=app.state.max_answer_len,
+            max_seq_len=app.state.max_seq_len,
+            doc_stride=app.state.doc_stride,
+        )
+
+        answer = result.get("answer", "")
+        score = result.get("score", 0.0)
+
+        # Two cases are treated as unanswerable:
+        # 1. the model itself predicts no-answer (empty answer string; only with handle_impossible_answer=True)
+        # 2. the confidence score falls below the user-defined score_threshold
+        # SQuAD 1.1-style models cannot predict no-answer and always return a span, so unanswerable questions simply come back wrong.
+        if not answer or (app.state.handle_impossible_answer and score < app.state.score_threshold):
+            answer = ""
+
+        print(f"Q: {req.question}", flush=True)
+        print(f"A: {answer!r} (score={score:.4f})", flush=True)
+        return {"answer": answer}
+
+    except Exception:
+        raise HTTPException(status_code=500, detail=f"Processing failed:\n{traceback.format_exc()}")
diff --git a/main_qa.py b/main_qa.py
new file mode 100644
index 0000000..70cd98c
--- /dev/null
+++ b/main_qa.py
@@ -0,0 +1,25 @@
+import argparse
+import uvicorn
+from fastapi_qa import app
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_dir", type=str, default="/model",
+                        help="Model directory (the path mounted into the container)")
+    parser.add_argument("--use_gpu", action=argparse.BooleanOptionalAction, default=True,
+                        help="Use the GPU (CUDA); pass --no-use_gpu to force CPU")
+    parser.add_argument("--port", type=int, default=8000,
+                        help="FastAPI service port (default: 8000)")
+
+    args = parser.parse_args()
+
+    app.state.config = {
+        "model_dir": args.model_dir,
+        "use_gpu": args.use_gpu,
+    }
+
+    uvicorn.run(app,  # pass the app object so app.state.config set above is preserved
+                host="0.0.0.0",
+                port=args.port,
+                workers=1,
+                )
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..629e89f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+requests
+wheel
+websocket-client
+pydantic>=2.0.0
+numpy<2.0
+PyYAML
+fastapi
+uvicorn
+python-multipart
+scipy
+sentencepiece
\ No newline at end of file
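
For quick verification beyond curl, a minimal Python client sketch against the two endpoints this patch adds (assuming the container from the README example is running and reachable on localhost:1111; an empty `answer` string means the request was judged unanswerable):

```python
import time
import requests

BASE_URL = "http://localhost:1111"  # port used in the README's docker run example

# Poll /health until the model finishes loading ("loading model" -> "ok").
while requests.get(f"{BASE_URL}/health").json()["status"] != "ok":
    time.sleep(1)

resp = requests.post(
    f"{BASE_URL}/qa",
    json={
        "context": "The capital city of China is Beijing",
        "question": "What is the capital city of China?",
    },
)
resp.raise_for_status()

# The service returns {"answer": "..."}; an empty string signals no-answer,
# either predicted by the model (handle_impossible_answer=True) or forced
# because the score fell below score_threshold.
answer = resp.json()["answer"]
print(answer or "<unanswerable>")
```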