# Sentence Transformer Server
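
A text embedding service built on FastAPI and Sentence Transformers. It supports text encoding and similarity computation.

## Features

- **Text encoding**: converts text into high-dimensional vector representations
- **Similarity computation**: computes the cosine similarity between two texts
- **RESTful API**: exposes standard HTTP endpoints
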
## Docker Deployment

### Build the image

```bash
docker build -t sentence-transformer-server .
```

### Run the container

#### GPU version (requires nvidia-docker)

```bash
docker run -d \
  --name st-server \
  --gpus all \
  -p 8000:8000 \
  -v /path/to/your/model:/model \
  sentence-transformer-server
```

#### CPU version

```bash
# First set DEVICE = "cpu" in server.py
docker run -d \
  --name st-server \
  -p 8000:8000 \
  -v /path/to/your/model:/model \
  sentence-transformer-server
```

**Note**: Replace `/path/to/your/model` with the actual path to your model files.

## API Endpoints

### 1. Health check

**Endpoint**: `GET /health`

**Response**:
```json
{
  "status": "ok"
}
```

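A script that starts the container may want to wait until the service answers before sending real requests. The following is a minimal sketch of such a readiness check using only the Python standard library. It assumes the server is reachable at `http://localhost:8000` (the port mapping used in the `docker run` examples); the helper names `is_healthy` and `wait_until_healthy` are illustrative and not part of this project.

```python
import json
import time
import urllib.error
import urllib.request


def is_healthy(base_url="http://localhost:8000", timeout=2.0):
    """Return True if GET /health answers with {"status": "ok"}."""
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=timeout) as resp:
            body = json.load(resp)
    except (urllib.error.URLError, OSError, ValueError):
        # Connection refused, timeout, or a non-JSON body: treat as not ready.
        return False
    return isinstance(body, dict) and body.get("status") == "ok"


def wait_until_healthy(check=is_healthy, attempts=30, interval=1.0):
    """Call `check` up to `attempts` times, sleeping `interval` seconds between tries."""
    for attempt in range(attempts):
        if check():
            return True
        if attempt < attempts - 1:
            time.sleep(interval)
    return False


# Usage, after `docker run`:
#   if not wait_until_healthy():
#       raise SystemExit("st-server did not become healthy in time")
```

The `check` parameter is injectable so the retry loop can be exercised without a running server.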
### 2. Text encoding

**Endpoint**: `POST /encode`

**Request body**:
```json
{
  "texts": ["This is a test text", "This is another text"],
  "normalize": true
}
```

**Parameters**:
- `texts`: list of texts to encode
- `normalize`: whether to normalize the embedding vectors (default: true)

**Response**:
```json
{
  "embeddings": [
    [0.123, 0.456, ...],
    [0.789, 0.234, ...]
  ]
}
```

**Example**:
```bash
curl -X POST http://localhost:8000/encode \
  -H "Content-Type: application/json" \
  -d '{"texts": ["hello world", "test text"], "normalize": true}'
```

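The same `/encode` call can be made from Python. The sketch below uses only the standard library and follows the request and response shapes documented above. The base URL and the function names `build_encode_request` and `encode` are assumptions made for illustration.

```python
import json
import urllib.request

# Assumed base URL, matching the `-p 8000:8000` mapping in the docker examples.
BASE_URL = "http://localhost:8000"


def build_encode_request(texts, normalize=True, base_url=BASE_URL):
    """Build (but do not send) a POST /encode request."""
    payload = {"texts": list(texts), "normalize": normalize}
    return urllib.request.Request(
        f"{base_url}/encode",
        # ensure_ascii=False keeps non-ASCII text readable in the request body.
        data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def encode(texts, normalize=True, base_url=BASE_URL, timeout=30.0):
    """Send the request and return the list of embedding vectors."""
    request = build_encode_request(texts, normalize, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.load(response)["embeddings"]


# Usage, with the server running:
#   vectors = encode(["hello world", "test text"])
#   print(len(vectors), len(vectors[0]))  # number of texts, embedding dimension
```

Building the request separately from sending it makes the payload easy to inspect or log before any network traffic happens.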
### 3. Similarity computation

**Endpoint**: `POST /similarity`

**Request body**:
```json
{
  "text1": "first text",
  "text2": "second text"
}
```

**Response**:
```json
{
  "similarity": 0.8567
}
```

**Example**:
```bash
curl -X POST http://localhost:8000/similarity \
  -H "Content-Type: application/json" \
  -d '{"text1": "I like eating apples", "text2": "I love eating fruit"}'
```

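For reference, cosine similarity is the dot product of two vectors divided by the product of their lengths. The sketch below is a plain-Python illustration of that standard definition, not the server's actual code. It also shows why normalized embeddings are convenient: once both vectors have unit length, the dot product alone is the cosine similarity.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vector):
    """Scale a vector to unit length (what `normalize: true` refers to)."""
    norm = math.sqrt(dot(vector, vector))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]


def cosine_similarity(a, b):
    """Cosine similarity, computed as the dot product of the normalized vectors."""
    return dot(l2_normalize(a), l2_normalize(b))


a = [3.0, 4.0]
b = [4.0, 3.0]
print(round(cosine_similarity(a, b), 4))  # 0.96
```

With embeddings returned by `/encode` using `"normalize": true`, calling `dot` on two of them should therefore be enough to compare the texts.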
## Configuration

### Model path

The model is loaded from the `/model` directory mounted into the container. To use a different path, change it in [server.py](server.py#L9):

```python
MODEL_NAME = "/model"
```

### Device configuration

Set the device to match your hardware in [server.py](server.py#L10):

```python
# Choose one of the following.

# NVIDIA GPU
DEVICE = "cuda"

# CPU
DEVICE = "cpu"

# Chinese domestic accelerators (requires code changes to support)
DEVICE = "npu"  # Huawei Ascend
DEVICE = "mlu"  # Cambricon
```

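As described above, switching between GPU and CPU means editing `server.py`. One possible alternative, sketched below, is to read the device from an environment variable so the same image can serve both cases. This is only a suggestion: the variable name `ST_DEVICE` and the function `resolve_device` are hypothetical, and the `"cuda"` default is inferred from the CPU instructions rather than confirmed by the source.

```python
import os

# Hypothetical variable name; server.py as documented hardcodes DEVICE instead.
DEVICE_ENV_VAR = "ST_DEVICE"
# Only the devices documented as working without code changes.
SUPPORTED_DEVICES = ("cuda", "cpu")


def resolve_device(environ=None, default="cuda"):
    """Return the device named in the environment, or `default` if unset."""
    environ = os.environ if environ is None else environ
    device = environ.get(DEVICE_ENV_VAR, default).strip().lower()
    if device not in SUPPORTED_DEVICES:
        raise ValueError(
            f"unsupported device {device!r}; expected one of {SUPPORTED_DEVICES}"
        )
    return device


print(resolve_device({"ST_DEVICE": "cpu"}))  # cpu
print(resolve_device({}))                    # cuda
```

With something like this in place, the CPU container could be started by adding `-e ST_DEVICE=cpu` to the `docker run` command instead of editing the source.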
## Dependencies

The main dependencies are listed in [requirements.txt](requirements.txt):
- fastapi
- uvicorn
- pydantic
- numpy
- sentence-transformers