51 lines
1.2 KiB
Python
51 lines
1.2 KiB
Python
|
|
from fastapi import FastAPI
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from typing import List
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from sentence_transformers import SentenceTransformer
|
||
|
|
|
||
|
|
# ===== Configuration =====
# Model identifier/path handed to SentenceTransformer.
MODEL_NAME = "/model"
# Device string handed to SentenceTransformer; switch it to your (e.g. domestic)
# accelerator's device name, for example "npu" / "mlu" / "cpu".
DEVICE = "cuda"

# ===== Load model =====
# Loaded once at import time and shared by every request handler.
model = SentenceTransformer(MODEL_NAME, device=DEVICE)

app = FastAPI()
|
||
|
|
|
||
|
|
# ===== Request schemas =====
|
||
|
|
class EncodeRequest(BaseModel):
    """Request body for the ``/encode`` endpoint."""

    # Texts to embed; the whole list is passed to ``model.encode`` in one call.
    texts: List[str]
    # Forwarded to ``model.encode`` as ``normalize_embeddings``.
    normalize: bool = True
|
||
|
|
|
||
|
|
class SimilarityRequest(BaseModel):
    """Request body for the ``/similarity`` endpoint: the two texts to compare."""

    text1: str
    text2: str
|
||
|
|
|
||
|
|
# ===== Utility functions =====
|
||
|
|
def cosine(a, b):
    """Return the cosine similarity of two vectors as a Python ``float``.

    Args:
        a: First vector (any array-like accepted by ``np.dot`` and
            ``np.linalg.norm``).
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (norm(a) * norm(b))``. When either vector has zero
        norm the similarity is undefined; ``0.0`` is returned instead of
        NaN so the result stays JSON-serializable.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    # Guard against 0/0, which would otherwise yield NaN plus a RuntimeWarning.
    if norm_product == 0:
        return 0.0
    return float(np.dot(a, b) / norm_product)
|
||
|
|
|
||
|
|
# ===== Endpoints =====
|
||
|
|
@app.post("/encode")
def encode(req: EncodeRequest):
    """Embed every text in the request and return the vectors as nested lists.

    ``req.normalize`` is forwarded to the model as ``normalize_embeddings``.
    The response has a single key, ``"embeddings"``, holding the result of
    ``.tolist()`` on the model output.
    """
    vectors = model.encode(req.texts, normalize_embeddings=req.normalize)
    return {"embeddings": vectors.tolist()}
|
||
|
|
|
||
|
|
@app.post("/similarity")
def similarity(req: SimilarityRequest):
    """Encode both texts in one batch and return their cosine similarity.

    The response has a single key, ``"similarity"``, holding the value
    computed by ``cosine`` on the two normalized embeddings.
    """
    pair = [req.text1, req.text2]
    vectors = model.encode(pair, normalize_embeddings=True)
    first_vec, second_vec = vectors[0], vectors[1]
    return {"similarity": cosine(first_vec, second_vec)}
|
||
|
|
|
||
|
|
@app.get("/health")
def health():
    """Return a fixed status payload; does not touch the model."""
    payload = {"status": "ok"}
    return payload
|