import os
import json
import traceback

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline as hf_pipeline

app = FastAPI()

# Model lifecycle state reported by /health:
# "Running" while loading, "Success" once the pipeline is built,
# "Failed" if building the pipeline raised.
status = "Running"
qa_pipeline = None

# Optional vendor accelerator backend. The plugin modules are imported only for
# their side effects (presumably registering the backend with torch -- confirm
# against each vendor's docs).
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu  # noqa: F401
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu  # noqa: F401
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu  # noqa: F401

# Accepted values for CONFIG_JSON["torch_dtype"]; "auto" is passed through to
# transformers unchanged.
_DTYPE_MAP = {
    "auto": "auto",
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


def _parse_config_json() -> dict:
    """Read optional settings from the CONFIG_JSON environment variable.

    Returns:
        The decoded JSON object, or an empty dict when the variable is unset,
        blank, not valid JSON, or does not decode to a JSON object.
    """
    raw = os.getenv("CONFIG_JSON", "").strip()
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"[WARN] CONFIG_JSON 解析失败,使用默认值: {e}", flush=True)
        return {}
    # Callers use .get() on the result, so a JSON array/number/string/null
    # would otherwise crash startup with AttributeError.
    if not isinstance(parsed, dict):
        print(
            f"[WARN] CONFIG_JSON is not a JSON object "
            f"(got {type(parsed).__name__}), using defaults",
            flush=True,
        )
        return {}
    return parsed


def _resolve_device(use_gpu) -> str:
    """Map the use_gpu flag and the CUSTOM_DEVICE backend to a device string."""
    if not use_gpu:
        return "cpu"
    if CUSTOM_DEVICE.startswith("mlu"):
        return "mlu:0"
    if CUSTOM_DEVICE.startswith("ascend"):
        return "npu:0"
    # Plain CUDA, and also the "pt*" (torch_dipu) backend, is addressed as cuda.
    return "cuda:0"


@app.on_event("startup")
def load_model():
    """Build the question-answering pipeline once at application startup.

    Reads ``app.state.config`` (set by main_qa.py; keys ``model_dir``, default
    "/model", and ``use_gpu``, default True) and CONFIG_JSON (keys
    ``torch_dtype``, ``handle_impossible_answer``, ``score_threshold``,
    ``max_answer_len``, ``max_seq_len``, ``doc_stride``).

    Stores the per-request inference parameters on ``app.state`` and sets the
    module-level ``qa_pipeline``. On success ``status`` becomes "Success"; if
    anything raises, ``status`` becomes "Failed" and the exception propagates.
    """
    global status, qa_pipeline
    try:
        # Provided by main_qa.py; fall back to the per-key defaults if absent.
        cfg = getattr(app.state, "config", None) or {}
        extra = _parse_config_json()  # from the CONFIG_JSON environment variable

        model_dir = cfg.get("model_dir", "/model")
        device = _resolve_device(cfg.get("use_gpu", True))

        # ---------- torch_dtype ----------
        dtype_str = extra.get("torch_dtype", "float32")
        if dtype_str not in _DTYPE_MAP:
            print(
                f"[WARN] unknown torch_dtype {dtype_str!r}, falling back to float32",
                flush=True,
            )
        torch_dtype = _DTYPE_MAP.get(dtype_str, torch.float32)

        # ---------- inference parameters (forwarded on every pipeline call) ----------
        state = app.state
        # handle_impossible_answer: enables SQuAD 2.0-style models to predict
        # "no answer", in which case the pipeline returns answer == "".
        state.handle_impossible_answer = extra.get("handle_impossible_answer", True)
        state.score_threshold = float(extra.get("score_threshold", 0.0))
        state.max_answer_len = int(extra.get("max_answer_len", 15))
        state.max_seq_len = int(extra.get("max_seq_len", 384))
        state.doc_stride = int(extra.get("doc_stride", 128))

        print(">> Startup config:", flush=True)
        print(f" model_dir = {model_dir}", flush=True)
        print(f" device = {device}", flush=True)
        print(f" torch_dtype = {torch_dtype}", flush=True)
        print(f" handle_impossible_answer= {state.handle_impossible_answer}", flush=True)
        print(f" score_threshold = {state.score_threshold}", flush=True)
        print(f" max_answer_len = {state.max_answer_len}", flush=True)
        print(f" max_seq_len = {state.max_seq_len}", flush=True)
        print(f" doc_stride = {state.doc_stride}", flush=True)

        qa_pipeline = hf_pipeline(
            task="question-answering",
            model=model_dir,
            device=device,
            torch_dtype=torch_dtype,
        )
    except Exception:
        # Make the failure visible through /health instead of "loading" forever.
        status = "Failed"
        raise
    status = "Success"
    print(">> Model loaded successfully.", flush=True)


class QARequest(BaseModel):
    """Request body for /qa: the passage to search and the question to answer."""

    context: str
    question: str


@app.get("/health")
def health():
    """Report model state: "loading model", "ok", or "failed"."""
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "ok" if status == "Success" else "failed"}


@app.post("/qa")
def qa(req: QARequest):
    """Answer ``req.question`` using ``req.context``.

    Returns ``{"answer": str}``, where the answer is "" when the question is
    judged unanswerable. Raises HTTP 503 until the model has loaded, and HTTP
    500 if inference fails.
    """
    if status != "Success":
        raise HTTPException(status_code=503, detail="Model not ready")
    state = app.state
    try:
        result = qa_pipeline(
            question=req.question,
            context=req.context,
            handle_impossible_answer=state.handle_impossible_answer,
            max_answer_len=state.max_answer_len,
            max_seq_len=state.max_seq_len,
            doc_stride=state.doc_stride,
        )
        answer = result.get("answer", "")
        score = result.get("score", 0.0)

        # Two cases are treated as unanswerable:
        # 1. The model itself predicts no-answer (empty answer; only possible
        #    when handle_impossible_answer=True).
        # 2. The confidence is below the configured score_threshold.
        # The threshold is applied only when handle_impossible_answer is on.
        # SQuAD 1.1-style models cannot abstain and always return a span, so
        # for them unanswerable questions are deliberately left to be wrong.
        if not answer or (state.handle_impossible_answer and score < state.score_threshold):
            answer = ""

        print(f"Q: {req.question}", flush=True)
        print(f"A: {answer!r} (score={score:.4f})", flush=True)
        return {"answer": answer}
    except Exception as exc:
        tb = traceback.format_exc()
        # Log server-side too, so failures are visible without the client.
        print(f"[ERROR] QA request failed:\n{tb}", flush=True)
        # NOTE(review): sending the full traceback to the client exposes
        # internal paths/details; kept for compatibility with existing callers.
        raise HTTPException(status_code=500, detail=f"Processing failed:\n{tb}") from exc