Files
2026-04-16 10:54:32 +08:00

138 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import traceback
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline as hf_pipeline
# ASGI application; the startup hook and the route handlers below attach to it.
app = FastAPI()

# Model-loading state reported by /health. Starts as "Running"; the startup
# hook updates it once the pipeline has been built.
status = "Running"

# Question-answering pipeline instance; assigned by the startup hook.
qa_pipeline = None

# Accelerator selector taken from the environment ("" when unset).
CUSTOM_DEVICE = os.environ.get("CUSTOM_DEVICE", "")

# Import the vendor extension matching the selected accelerator.
# NOTE(review): these imports are presumably needed only for their side effect
# of registering the device backend with torch -- confirm.
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu
# Maps the "torch_dtype" configuration string to the value handed to the
# pipeline. "auto" is passed through as the literal string; every other key is
# the torch dtype attribute of the same name.
_DTYPE_MAP = {
    "auto": "auto",
    **{name: getattr(torch, name) for name in ("float32", "float16", "bfloat16")},
}
def _parse_config_json() -> dict:
    """Read optional settings from the ``CONFIG_JSON`` environment variable.

    Returns:
        The decoded JSON object. An empty dict is returned when the variable
        is unset or blank, when it is not valid JSON, or when it is valid JSON
        but not an object (a warning is printed in the last two cases).
    """
    raw = os.getenv("CONFIG_JSON", "").strip()
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"[WARN] CONFIG_JSON 解析失败,使用默认值: {e}", flush=True)
        return {}
    # json.loads also accepts lists, strings, numbers and null. Callers use
    # .get() on the result, so anything other than an object falls back to the
    # defaults instead of failing later with AttributeError.
    if not isinstance(parsed, dict):
        print(
            "[WARN] CONFIG_JSON must be a JSON object, "
            f"got {type(parsed).__name__}; using defaults",
            flush=True,
        )
        return {}
    return parsed
@app.on_event("startup")
def load_model():
    """Build the question-answering pipeline once at application startup.

    Reads ``model_dir`` and ``use_gpu`` from ``app.state.config``, reads the
    optional tuning keys from ``CONFIG_JSON``, stores the per-request
    inference parameters on ``app.state``, constructs the pipeline and sets
    the module-level ``status`` to ``"Success"``.

    Raises:
        Exception: any error raised while parsing the settings or constructing
            the pipeline is re-raised after ``status`` has been set to
            ``"Failed"``, so ``/health`` can report the failure.
    """
    global status, qa_pipeline
    try:
        # Per the original comment this config is supplied by main_qa.py;
        # fall back to the defaults below when it has not been set.
        cfg = getattr(app.state, "config", None) or {}
        # Optional overrides from the CONFIG_JSON environment variable.
        extra = _parse_config_json()
        model_dir = cfg.get("model_dir", "/model")
        use_gpu = cfg.get("use_gpu", True)

        # ---------- device ----------
        device = "cpu"
        if use_gpu:
            if CUSTOM_DEVICE.startswith("mlu"):
                device = "mlu:0"
            elif CUSTOM_DEVICE.startswith("ascend"):
                device = "npu:0"
            else:
                # Every other value, including the "pt" backend, uses the
                # CUDA device name.
                device = "cuda:0"

        # ---------- torch_dtype ----------
        dtype_str = extra.get("torch_dtype", "float32")
        if dtype_str in _DTYPE_MAP:
            torch_dtype = _DTYPE_MAP[dtype_str]
        else:
            # Keep the float32 fallback, but make a misspelled value visible.
            print(
                f"[WARN] unknown torch_dtype {dtype_str!r}; "
                f"expected one of {sorted(_DTYPE_MAP)}, using float32",
                flush=True,
            )
            torch_dtype = torch.float32

        # ---------- pipeline inference parameters (forwarded on every call) ----------
        # handle_impossible_answer: supports SQuAD 2.0-style models, which
        # return answer="" when they predict that there is no answer.
        app.state.handle_impossible_answer = extra.get("handle_impossible_answer", True)
        app.state.score_threshold = float(extra.get("score_threshold", 0.0))
        app.state.max_answer_len = int(extra.get("max_answer_len", 15))
        app.state.max_seq_len = int(extra.get("max_seq_len", 384))
        app.state.doc_stride = int(extra.get("doc_stride", 128))

        print(">> Startup config:", flush=True)
        print(f" model_dir = {model_dir}", flush=True)
        print(f" device = {device}", flush=True)
        print(f" torch_dtype = {torch_dtype}", flush=True)
        print(f" handle_impossible_answer= {app.state.handle_impossible_answer}", flush=True)
        print(f" score_threshold = {app.state.score_threshold}", flush=True)
        print(f" max_answer_len = {app.state.max_answer_len}", flush=True)
        print(f" max_seq_len = {app.state.max_seq_len}", flush=True)
        print(f" doc_stride = {app.state.doc_stride}", flush=True)

        qa_pipeline = hf_pipeline(
            task="question-answering",
            model=model_dir,
            device=device,
            torch_dtype=torch_dtype,
        )
    except Exception:
        # Record the failure so status no longer reads "Running", then let the
        # error propagate exactly as before.
        status = "Failed"
        print(">> Model loading failed.", flush=True)
        raise
    status = "Success"
    print(">> Model loaded successfully.", flush=True)
# Request body for POST /qa. Documented with comments rather than a docstring
# so the generated schema stays unchanged.
class QARequest(BaseModel):
    # Text passed to the pipeline as `context`.
    context: str
    # Text passed to the pipeline as `question`.
    question: str
# Health probe: translates the module-level `status` into a small JSON body.
#   "Success" -> {"status": "ok"}
#   "Running" -> {"status": "loading model"}
#   any other -> {"status": "failed"}
@app.get("/health")
def health():
    if status == "Success":
        return {"status": "ok"}
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "failed"}
@app.post("/qa")
def qa(req: QARequest):
    """Answer ``req.question`` from ``req.context`` with the loaded pipeline.

    Returns:
        ``{"answer": <str>}``. The answer is ``""`` when the model predicts no
        answer, or when ``handle_impossible_answer`` is enabled and the score
        is below ``score_threshold``.

    Raises:
        HTTPException: 503 while the model is not loaded, 400 when the
            question or the context is empty, 500 when inference fails.
    """
    if status != "Success":
        raise HTTPException(status_code=503, detail="Model not ready")
    # The pipeline rejects empty strings with ValueError; report that as a
    # client error instead of letting it surface as a 500.
    if not req.question or not req.context:
        raise HTTPException(
            status_code=400,
            detail="question and context must be non-empty",
        )
    try:
        result = qa_pipeline(
            question=req.question,
            context=req.context,
            handle_impossible_answer=app.state.handle_impossible_answer,
            max_answer_len=app.state.max_answer_len,
            max_seq_len=app.state.max_seq_len,
            doc_stride=app.state.doc_stride,
        )
        answer = result.get("answer", "")
        score = result.get("score", 0.0)
        # Two cases count as "cannot answer":
        #   1. the model itself predicts no-answer (answer is an empty string;
        #      happens when handle_impossible_answer=True);
        #   2. the confidence is below the configured score_threshold.
        # The threshold is deliberately skipped when handle_impossible_answer
        # is off: a SQuAD 1.1-style model always produces an answer, so
        # unanswerable questions are simply allowed to come out wrong.
        if not answer or (app.state.handle_impossible_answer and score < app.state.score_threshold):
            answer = ""
        print(f"Q: {req.question}", flush=True)
        print(f"A: {answer!r} (score={score:.4f})", flush=True)
        return {"answer": answer}
    except Exception as exc:
        # Keep the stack trace in the server log; the client receives only the
        # exception type and message, so file paths and internals are not
        # exposed over HTTP.
        print(traceback.format_exc(), flush=True)
        raise HTTPException(
            status_code=500,
            detail=f"Processing failed: {type(exc).__name__}: {exc}",
        ) from exc