"""HTTP service exposing sentence-embedding and text-similarity endpoints.

The SentenceTransformer model is loaded once at import time and shared by
all request handlers.
"""

from typing import List

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# ===== Configuration =====
MODEL_NAME = "/model"
# Change to match the target accelerator (e.g. a domestic card):
# "npu" / "mlu" / "cpu".
DEVICE = "cuda"

# ===== Model loading =====
model = SentenceTransformer(MODEL_NAME, device=DEVICE)

app = FastAPI()


# ===== Request schemas =====
class EncodeRequest(BaseModel):
    """Request body for ``POST /encode``."""

    # Texts to embed.
    texts: List[str]
    # Forwarded to ``model.encode`` as ``normalize_embeddings``.
    normalize: bool = True


class SimilarityRequest(BaseModel):
    """Request body for ``POST /similarity``."""

    text1: str
    text2: str


# ===== Helpers =====
def cosine(a, b):
    """Return the cosine similarity of vectors *a* and *b* as a float.

    Args:
        a: First vector (anything ``np.dot`` / ``np.linalg.norm`` accept).
        b: Second vector, same length as *a*.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when either vector has zero
        norm.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    # A zero-norm vector would make the division produce NaN. NaN is not
    # valid JSON, so it cannot be returned safely from an endpoint; treat
    # the similarity as 0.0 instead.
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)


# ===== Endpoints =====
@app.post("/encode")
def encode(req: EncodeRequest):
    """Embed ``req.texts`` and return the embeddings as nested lists."""
    embeddings = model.encode(
        req.texts,
        normalize_embeddings=req.normalize,
    )
    return {"embeddings": embeddings.tolist()}


@app.post("/similarity")
def similarity(req: SimilarityRequest):
    """Return the cosine similarity between ``req.text1`` and ``req.text2``."""
    # Encode both texts in a single call so they share one batch.
    emb = model.encode([req.text1, req.text2], normalize_embeddings=True)
    sim = cosine(emb[0], emb[1])
    return {"similarity": sim}


@app.get("/health")
def health():
    """Health probe: always reports ``{"status": "ok"}``."""
    return {"status": "ok"}