diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e980443 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +FROM git.modelhub.org.cn:9443/enginex-metax/vllm:0.9.1 + +WORKDIR /workspace + +# 复制 requirements.txt 并安装 Python 依赖 +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制 server.py 到 workspace +COPY server.py /workspace/ + +# 暴露端口 +EXPOSE 8000 + +# 启动服务 +CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md index ce3ca28..d6e170f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,151 @@ -# enginex_metax_c_series-feature-extraction +# Sentence Transformer Server + +基于 FastAPI 和 Sentence Transformers 的文本向量化服务,支持文本编码和相似度计算。 + +## 功能特性 + +- **文本编码**:将文本转换为高维向量表示 +- **相似度计算**:计算两个文本之间的余弦相似度 +- **RESTful API**:提供标准的 HTTP 接口 + +## Docker 部署 + +### 构建镜像 + +```bash +docker build -t sentence-transformer-server . +``` + +### 运行容器 + +#### GPU 版本(需要 nvidia-docker) + +```bash +docker run -d \ + --name st-server \ + --gpus all \ + -p 8000:8000 \ + -v /path/to/your/model:/model \ + sentence-transformer-server +``` + +#### CPU 版本 + +```bash +# 先修改 server.py 中的 DEVICE = "cpu" +docker run -d \ + --name st-server \ + -p 8000:8000 \ + -v /path/to/your/model:/model \ + sentence-transformer-server +``` + +**注意**:将 `/path/to/your/model` 替换为实际的模型文件路径 + +## API 接口 + +### 1. 健康检查 + +**接口**:`GET /health` + +**响应**: +```json +{ + "status": "ok" +} +``` + +### 2. 文本编码 + +**接口**:`POST /encode` + +**请求体**: +```json +{ + "texts": ["这是一段测试文本", "这是另一段文本"], + "normalize": true +} +``` + +**参数说明**: +- `texts`:待编码的文本列表 +- `normalize`:是否对向量进行归一化(默认 true) + +**响应**: +```json +{ + "embeddings": [ + [0.123, 0.456, ...], + [0.789, 0.234, ...] + ] +} +``` + +**示例**: +```bash +curl -X POST http://localhost:8000/encode \ + -H "Content-Type: application/json" \ + -d '{"texts": ["你好世界", "测试文本"], "normalize": true}' +``` + +### 3. 相似度计算 + +**接口**:`POST /similarity` + +**请求体**: +```json +{ + "text1": "第一段文本", + "text2": "第二段文本" +} +``` + +**响应**: +```json +{ + "similarity": 0.8567 +} +``` + +**示例**: +```bash +curl -X POST http://localhost:8000/similarity \ + -H "Content-Type: application/json" \ + -d '{"text1": "我喜欢吃苹果", "text2": "我爱吃水果"}' +``` + +## 配置说明 + +### 模型路径 + +模型路径通过容器内的 `/model` 目录挂载,可在 [server.py](server.py#L9) 中修改: + +```python +MODEL_NAME = "/model" +``` + +### 设备配置 + +根据实际硬件环境修改设备配置,[server.py](server.py#L10): + +```python +# NVIDIA GPU +DEVICE = "cuda" + +# CPU +DEVICE = "cpu" + +# 国产芯片(需修改代码支持) +DEVICE = "npu" # 华为昇腾 +DEVICE = "mlu" # 寒武纪 +``` + +## 依赖包 + +主要依赖项见 [requirements.txt](requirements.txt): +- fastapi +- uvicorn +- pydantic +- numpy +- sentence-transformers diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9e8e05f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +fastapi==0.116.1 +uvicorn==0.35.0 +pydantic==2.11.7 +numpy==1.26.4 +sentence-transformers==5.3.0 diff --git a/server.py b/server.py new file mode 100644 index 0000000..75aad61 --- /dev/null +++ b/server.py @@ -0,0 +1,51 @@ +from fastapi import FastAPI +from pydantic import BaseModel +from typing import List +import numpy as np + +from sentence_transformers import SentenceTransformer + +# ===== 配置 ===== +MODEL_NAME = "/model" +DEVICE = "cuda" # 改成国产卡设备,例如 "npu" / "mlu" / "cpu" + +# ===== 加载模型 ===== +model = SentenceTransformer(MODEL_NAME, device=DEVICE) + +app = FastAPI() + +# ===== 请求结构 ===== +class EncodeRequest(BaseModel): + texts: List[str] + normalize: bool = True + +class SimilarityRequest(BaseModel): + text1: str + text2: str + +# ===== 工具函数 ===== +def cosine(a, b): + return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))) + +# ===== 接口 ===== +@app.post("/encode") +def encode(req: EncodeRequest): + embeddings = model.encode( + req.texts, + normalize_embeddings=req.normalize + ) + return { + "embeddings": embeddings.tolist() + } + +@app.post("/similarity") +def similarity(req: SimilarityRequest): + emb = model.encode([req.text1, req.text2], normalize_embeddings=True) + sim = cosine(emb[0], emb[1]) + return { + "similarity": sim + } + +@app.get("/health") +def health(): + return {"status": "ok"} \ No newline at end of file