add metax sentence-transformers engine
Signed-off-by: Sun Ruoxi <sunruoxi@4paradigm.com>
This commit is contained in:
51
server.py
Normal file
51
server.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# ===== 配置 =====
|
||||
MODEL_NAME = "/model"
|
||||
DEVICE = "cuda" # 改成国产卡设备,例如 "npu" / "mlu" / "cpu"
|
||||
|
||||
# ===== 加载模型 =====
|
||||
model = SentenceTransformer(MODEL_NAME, device=DEVICE)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# ===== 请求结构 =====
|
||||
class EncodeRequest(BaseModel):
|
||||
texts: List[str]
|
||||
normalize: bool = True
|
||||
|
||||
class SimilarityRequest(BaseModel):
|
||||
text1: str
|
||||
text2: str
|
||||
|
||||
# ===== 工具函数 =====
|
||||
def cosine(a, b):
|
||||
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
|
||||
|
||||
# ===== 接口 =====
|
||||
@app.post("/encode")
|
||||
def encode(req: EncodeRequest):
|
||||
embeddings = model.encode(
|
||||
req.texts,
|
||||
normalize_embeddings=req.normalize
|
||||
)
|
||||
return {
|
||||
"embeddings": embeddings.tolist()
|
||||
}
|
||||
|
||||
@app.post("/similarity")
|
||||
def similarity(req: SimilarityRequest):
|
||||
emb = model.encode([req.text1, req.text2], normalize_embeddings=True)
|
||||
sim = cosine(emb[0], emb[1])
|
||||
return {
|
||||
"similarity": sim
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
Reference in New Issue
Block a user