initial commit
.dockerignore (new file)
@@ -0,0 +1 @@
test_scripts/
Dockerfile.qa_bi150 (new file)
@@ -0,0 +1,13 @@
FROM corex:4.3.8

WORKDIR /root

ADD . /root/

COPY requirements.txt /root
RUN pip install -r requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple
# torch is installed to provide the CUDA library environment
RUN pip install transformers==4.51.3 -i https://nexus.4pd.io/repository/pypi-all/simple

ENTRYPOINT ["python3"]
CMD ["./main_qa.py"]
README.md (modified; replaces the stub "# enginex-bi_150-question-answering")
@@ -1,2 +1,39 @@
# Iluvatar CoreX Tiangai-150 Text Question Answering

## Building the image

```shell
docker build -f ./Dockerfile.qa_bi150 -t <your_image> .
```

The base image corex:4.3.8 can be obtained by contacting Iluvatar CoreX (天数智芯) vendor technical support for the Zhikai-100 (智铠100).

## Usage

### Starting the text QA service with FastAPI

For example:

```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
  -v /mnt/contest_ceph/leaderboard/modelHubXC/csarron/bert-base-uncased-squad-v1:/model \
  --network=host -e CONFIG_JSON='{
    "torch_dtype": "auto",
    "handle_impossible_answer": false,
    "score_threshold": 0.0,
    "max_answer_len": 30,
    "max_seq_len": 384,
    "doc_stride": 128
  }' \
  --entrypoint=python3 <your_image> \
  main_qa.py --model_dir /model --port 1111
```

For the specific parameter settings, refer to the code files; the supported CONFIG_JSON keys are summarized below.

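As set in fastapi_qa.py, `CONFIG_JSON` supports the following keys (defaults in parentheses):

- `torch_dtype` (`"float32"`): one of `auto`, `float32`, `float16`, `bfloat16`
- `handle_impossible_answer` (`true`): enable SQuAD 2.0-style no-answer predictions
- `score_threshold` (`0.0`): answers scoring below this are returned as empty
- `max_answer_len` (`15`)
- `max_seq_len` (`384`)
- `doc_stride` (`128`)
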
### Testing the service

```shell
curl -X POST http://localhost:1111/qa \
  -H "Content-Type: application/json" \
  -d '{
    "context": "The capital city of China is Beijing",
    "question": "What is the capital city of China?"
  }'
```
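The service replies with JSON of the form `{"answer": "..."}`; for the request above a correctly loaded model should return `{"answer": "Beijing"}`. A `GET /health` endpoint is also exposed and reports `loading model`, `ok`, or `failed` (see fastapi_qa.py).
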
fastapi_qa.py (new file)
@@ -0,0 +1,137 @@
import os
import json
import traceback
import torch

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline as hf_pipeline

app = FastAPI()

status = "Running"
qa_pipeline = None

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu

_DTYPE_MAP = {
    "auto": "auto",
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


def _parse_config_json() -> dict:
    """Read optional configuration from the CONFIG_JSON environment variable; return an empty dict when unset."""
    raw = os.getenv("CONFIG_JSON", "").strip()
    if not raw:
        return {}
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"[WARN] Failed to parse CONFIG_JSON, falling back to defaults: {e}", flush=True)
        return {}


@app.on_event("startup")
def load_model():
    global status, qa_pipeline

    cfg = app.state.config        # set by main_qa.py
    extra = _parse_config_json()  # from the CONFIG_JSON environment variable

    model_dir = cfg.get("model_dir", "/model")
    use_gpu = cfg.get("use_gpu", True)

    # ---------- device ----------
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    # ---------- torch_dtype ----------
    dtype_str = extra.get("torch_dtype", "float32")
    torch_dtype = _DTYPE_MAP.get(dtype_str, torch.float32)

    # ---------- pipeline inference params (passed through to every __call__) ----------
    # handle_impossible_answer: supports SQuAD 2.0-style models; answer="" is
    # returned when the model predicts no answer
    app.state.handle_impossible_answer = extra.get("handle_impossible_answer", True)
    app.state.score_threshold = float(extra.get("score_threshold", 0.0))
    app.state.max_answer_len = int(extra.get("max_answer_len", 15))
    app.state.max_seq_len = int(extra.get("max_seq_len", 384))
    app.state.doc_stride = int(extra.get("doc_stride", 128))

    print(">> Startup config:", flush=True)
    print(f"   model_dir               = {model_dir}", flush=True)
    print(f"   device                  = {device}", flush=True)
    print(f"   torch_dtype             = {torch_dtype}", flush=True)
    print(f"   handle_impossible_answer= {app.state.handle_impossible_answer}", flush=True)
    print(f"   score_threshold         = {app.state.score_threshold}", flush=True)
    print(f"   max_answer_len          = {app.state.max_answer_len}", flush=True)
    print(f"   max_seq_len             = {app.state.max_seq_len}", flush=True)
    print(f"   doc_stride              = {app.state.doc_stride}", flush=True)

    qa_pipeline = hf_pipeline(
        task="question-answering",
        model=model_dir,
        device=device,
        torch_dtype=torch_dtype,
    )

    status = "Success"
    print(">> Model loaded successfully.", flush=True)


class QARequest(BaseModel):
    context: str
    question: str


@app.get("/health")
def health():
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "ok" if status == "Success" else "failed"}


@app.post("/qa")
def qa(req: QARequest):
    if status != "Success":
        raise HTTPException(status_code=503, detail="Model not ready")
    try:
        result = qa_pipeline(
            question=req.question,
            context=req.context,
            handle_impossible_answer=app.state.handle_impossible_answer,
            max_answer_len=app.state.max_answer_len,
            max_seq_len=app.state.max_seq_len,
            doc_stride=app.state.doc_stride,
        )

        answer = result.get("answer", "")
        score = result.get("score", 0.0)

        # Two cases are treated as unanswerable:
        # 1. the model itself predicts no-answer (answer is the empty string;
        #    triggered when handle_impossible_answer=True)
        # 2. the confidence score is below the user-configured score_threshold
        # For SQuAD 1.1 models, adversarial (unanswerable) questions are simply
        # allowed to be wrong, since such models cannot handle negative examples
        # and will always produce an answer.
        if not answer or (app.state.handle_impossible_answer and score < app.state.score_threshold):
            answer = ""

        print(f"Q: {req.question}", flush=True)
        print(f"A: {answer!r} (score={score:.4f})", flush=True)
        return {"answer": answer}

    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed:\n{traceback.format_exc()}")
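For reference, a minimal client sketch (not part of the commit) that exercises the two endpoints above, assuming the container from the README is listening on port 1111; `requests` is already listed in requirements.txt:

```python
import requests

BASE = "http://localhost:1111"  # host/port taken from the README's docker run example

# Readiness check: /health reports "loading model", "ok", or "failed".
print(requests.get(f"{BASE}/health").json())

# Ask a question against a context paragraph.
resp = requests.post(
    f"{BASE}/qa",
    json={
        "context": "The capital city of China is Beijing",
        "question": "What is the capital city of China?",
    },
)
resp.raise_for_status()
print(resp.json())  # expected shape: {"answer": "..."}
```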
main_qa.py (new file)
@@ -0,0 +1,25 @@
import argparse
import uvicorn
from fastapi_qa import app

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/model",
                        help="Model directory (the path mounted into the container)")
    # BooleanOptionalAction (Python 3.9+) exposes both --use_gpu and --no-use_gpu;
    # a plain store_true flag with default=True could never be disabled
    parser.add_argument("--use_gpu", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to use the GPU (CUDA)")
    parser.add_argument("--port", type=int, default=8000,
                        help="FastAPI service port (default 8000)")

    args = parser.parse_args()

    app.state.config = {
        "model_dir": args.model_dir,
        "use_gpu": args.use_gpu,
    }

    # workers=1 keeps uvicorn in-process, so the fastapi_qa module imported
    # above (including app.state.config) is reused rather than re-imported
    uvicorn.run("fastapi_qa:app",
                host="0.0.0.0",
                port=args.port,
                workers=1,
                )
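main_qa.py can also be launched outside the container, e.g. `python3 main_qa.py --model_dir /model --port 8000`, provided the dependencies below plus torch and transformers are installed (inside the image, torch comes from the corex base image).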
requirements.txt (new file)
@@ -0,0 +1,11 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PyYAML
fastapi
uvicorn
python-multipart
scipy
sentencepiece