Files
2026-04-14 19:01:27 +08:00

51 lines
1.2 KiB
Python

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import numpy as np
from sentence_transformers import SentenceTransformer
# ===== Configuration =====
# MODEL_NAME is passed to SentenceTransformer below as the model to load.
MODEL_NAME = "/model"
DEVICE = "cuda" # change to the domestic accelerator device in use, e.g. "npu" / "mlu" / "cpu"
# ===== Load the model =====
# The model is instantiated once at import time and shared by the endpoint
# handlers defined later in this file.
model = SentenceTransformer(MODEL_NAME, device=DEVICE)
# FastAPI application object that the route decorators below attach to.
app = FastAPI()
# ===== Request schemas =====
class EncodeRequest(BaseModel):
    """Request body for the ``/encode`` endpoint."""

    # Input strings to embed; forwarded to ``model.encode`` as-is.
    texts: List[str]
    # Forwarded as ``normalize_embeddings``; defaults to True when omitted.
    normalize: bool = True
class SimilarityRequest(BaseModel):
    """Request body for the ``/similarity`` endpoint: the two texts to compare."""

    # First text of the pair.
    text1: str
    # Second text of the pair.
    text2: str
# ===== Utility functions =====
def cosine(a, b):
    """Return the cosine similarity of vectors *a* and *b* as a Python float.

    Args:
        a: First vector (1-D array-like accepted by NumPy).
        b: Second vector, same length as *a*.

    Returns:
        ``dot(a, b) / (||a|| * ||b||)``, or ``0.0`` when either vector has
        zero norm, since the ratio is undefined in that case.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    # A zero-norm input would make this 0/0 and yield NaN (plus a NumPy
    # RuntimeWarning); NaN is not a valid JSON number, so return 0.0 instead.
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)
# ===== Endpoints =====
@app.post("/encode")
def encode(req: EncodeRequest):
    """Embed every text in the request and return the vectors as nested lists.

    ``req.normalize`` is forwarded to the model as ``normalize_embeddings``.
    The response is a dict with a single ``"embeddings"`` key holding the
    result of ``.tolist()`` on the model output.
    """
    vectors = model.encode(req.texts, normalize_embeddings=req.normalize)
    return {"embeddings": vectors.tolist()}
@app.post("/similarity")
def similarity(req: SimilarityRequest):
    """Return the cosine similarity between ``req.text1`` and ``req.text2``.

    Both texts are encoded in one model call with
    ``normalize_embeddings=True``; the two resulting vectors are then
    compared with ``cosine``. The response is ``{"similarity": <float>}``.
    """
    pair = [req.text1, req.text2]
    first_vec, second_vec = model.encode(pair, normalize_embeddings=True)
    return {"similarity": cosine(first_vec, second_vec)}
@app.get("/health")
def health():
    """Health-check endpoint: always answers with a fixed ``status: ok`` payload."""
    payload = {"status": "ok"}
    return payload