initial commit

This commit is contained in:
2026-04-16 10:54:32 +08:00
parent 112ad63ebb
commit 189f999152
6 changed files with 225 additions and 1 deletions

137
fastapi_qa.py Normal file
View File

@@ -0,0 +1,137 @@
import os
import json
import traceback
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline as hf_pipeline
app = FastAPI()
status = "Running"
qa_pipeline = None
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
import torch_dipu
_DTYPE_MAP = {
"auto": "auto",
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
def _parse_config_json() -> dict:
"""从环境变量 CONFIG_JSON 读取可选配置,未设置时返回空字典。"""
raw = os.getenv("CONFIG_JSON", "").strip()
if not raw:
return {}
try:
return json.loads(raw)
except json.JSONDecodeError as e:
print(f"[WARN] CONFIG_JSON 解析失败,使用默认值: {e}", flush=True)
return {}
@app.on_event("startup")
def load_model():
global status, qa_pipeline
cfg = app.state.config # 来自 main_qa.py
extra = _parse_config_json() # 来自 CONFIG_JSON 环境变量
model_dir = cfg.get("model_dir", "/model")
use_gpu = cfg.get("use_gpu", True)
# ---------- 设备 ----------
device = "cpu"
if use_gpu:
if CUSTOM_DEVICE.startswith("mlu"):
device = "mlu:0"
elif CUSTOM_DEVICE.startswith("ascend"):
device = "npu:0"
else:
device = "cuda:0"
# ---------- torch_dtype ----------
dtype_str = extra.get("torch_dtype", "float32")
torch_dtype = _DTYPE_MAP.get(dtype_str, torch.float32)
# ---------- pipeline 推理参数(透传给每次 __call__----------
# handle_impossible_answer: 支持 SQuAD 2.0 风格模型,预测无答案时返回 answer=""
app.state.handle_impossible_answer = extra.get("handle_impossible_answer", True)
app.state.score_threshold = float(extra.get("score_threshold", 0.0))
app.state.max_answer_len = int(extra.get("max_answer_len", 15))
app.state.max_seq_len = int(extra.get("max_seq_len", 384))
app.state.doc_stride = int(extra.get("doc_stride", 128))
print(">> Startup config:", flush=True)
print(f" model_dir = {model_dir}", flush=True)
print(f" device = {device}", flush=True)
print(f" torch_dtype = {torch_dtype}", flush=True)
print(f" handle_impossible_answer= {app.state.handle_impossible_answer}", flush=True)
print(f" score_threshold = {app.state.score_threshold}", flush=True)
print(f" max_answer_len = {app.state.max_answer_len}", flush=True)
print(f" max_seq_len = {app.state.max_seq_len}", flush=True)
print(f" doc_stride = {app.state.doc_stride}", flush=True)
qa_pipeline = hf_pipeline(
task="question-answering",
model=model_dir,
device=device,
torch_dtype=torch_dtype,
)
status = "Success"
print(">> Model loaded successfully.", flush=True)
class QARequest(BaseModel):
context: str
question: str
@app.get("/health")
def health():
if status == "Running":
return {"status": "loading model"}
return {"status": "ok" if status == "Success" else "failed"}
@app.post("/qa")
def qa(req: QARequest):
if status != "Success":
raise HTTPException(status_code=503, detail="Model not ready")
try:
result = qa_pipeline(
question=req.question,
context=req.context,
handle_impossible_answer=app.state.handle_impossible_answer,
max_answer_len=app.state.max_answer_len,
max_seq_len=app.state.max_seq_len,
doc_stride=app.state.doc_stride,
)
answer = result.get("answer", "")
score = result.get("score", 0.0)
# 两种情况视为无法回答:
# 1. 模型本身预测 no-answeranswer 为空串handle_impossible_answer=True 时触发)
# 2. 置信度低于用户设定的 score_threshold
# 另外对于SQuad 1.1模型,问到反例就让他错因为模型没有处理反例能力一定会给出答案
if not answer or (app.state.handle_impossible_answer and score < app.state.score_threshold):
answer = ""
print(f"Q: {req.question}", flush=True)
print(f"A: {answer!r} (score={score:.4f})", flush=True)
return {"answer": answer}
except Exception:
raise HTTPException(status_code=500, detail=f"Processing failed:\n{traceback.format_exc()}")