init project
This commit is contained in:
13
Dockerfile
Normal file
13
Dockerfile
Normal file
@@ -0,0 +1,13 @@
|
||||
# Iluvatar BI-150 vLLM base image (ships torch/transformers for the accelerator).
FROM git.modelhub.org.cn:9443/enginex-iluvatar-bi150/vllm:0.8.3

WORKDIR /workspace

# Install the thin serving dependencies first so this layer stays cached
# until requirements.txt itself changes.
COPY requirements.txt /workspace
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

# COPY instead of ADD: no tar auto-extraction or URL handling is needed here.
COPY . /workspace

EXPOSE 8000

# Exec form without the sh -c wrapper: python3 becomes PID 1 and receives
# SIGTERM directly on `docker stop` instead of it going to the shell.
CMD ["python3", "transformers_server.py"]
|
||||
|
||||
|
||||
19
README.md
19
README.md
@@ -1,2 +1,21 @@
|
||||
# enginex-bi_150-text-classification
|
||||
## Quickstart
|
||||
### 启动服务
|
||||
修改docker.sh的脚本中$mountpath为本地的模型挂载路径
|
||||
然后运行./docker.sh
|
||||
当打印出以下内容时表示模型load成功
|
||||
```
|
||||
2026-04-10 10:01:00 /workspace/transformers_server.py INFO model loaded successfully
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
|
||||
```
|
||||
|
||||
### 运行测试
|
||||
在容器中执行 python3 test.py
|
||||
打印出以下内容, 文本分类结果的id和label:
|
||||
```
|
||||
text: 这家餐厅的服务特别周到,菜也很好吃,我下次还会再来。
|
||||
predicted_class_id: 4
|
||||
predicted_label: 5 stars
|
||||
```
|
||||
|
||||
|
||||
6
docker.sh
Executable file
6
docker.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
#!/bin/sh
# Rebuild and (re)start the BI-150 text-classification container.
# Set $mountpath to the local model directory before running (see README).

# Remove any previous instance; suppress the error when none exists.
docker stop bi150_tc 2>/dev/null || true
docker rm bi150_tc 2>/dev/null || true

docker build . -t bi150_text_classification

# Quote the mount path so paths containing spaces survive word splitting.
docker run --rm -p 17777:8000 -v "$mountpath":/model:ro -it --device=/dev/iluvatar0:/dev/iluvatar0 --name bi150_tc -e CONFIG_JSON='{"model_class": "AutoModelForSequenceClassification", "tokenizer_class": "AutoTokenizer", "torch_dtype": "auto"}' bi150_text_classification
|
||||
13
logger.py
Normal file
13
logger.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
"""Shared logging setup: configure the root logger once, hand out named loggers."""

import logging
import os

# Configure the root logger at import time so every module that calls
# get_logger() inherits the same format and level.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    # .upper() makes LOGLEVEL case-insensitive ("debug" works as well as
    # "DEBUG"); a lowercase value would otherwise raise ValueError at import.
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)


def get_logger(file):
    """Return the logger named after *file* (callers pass __file__)."""
    return logging.getLogger(file)
|
||||
|
||||
2
requirements.txt
Normal file
2
requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
25
test.py
Normal file
25
test.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def main() -> None:
    """CLI helper: POST one text to the classify endpoint and pretty-print the reply."""
    arg_parser = argparse.ArgumentParser(description="Send a test request to /v1/classify")
    arg_parser.add_argument("--url", default="http://127.0.0.1:8000/v1/classify")
    arg_parser.add_argument("--text", default="这家餐厅的服务特别周到,菜也很好吃,我下次还会再来。")
    arg_parser.add_argument("--timeout", type=int, default=60)
    opts = arg_parser.parse_args()

    payload = {"text": opts.text}
    reply = requests.post(opts.url, json=payload, timeout=opts.timeout)
    reply.raise_for_status()
    # ensure_ascii=False keeps the Chinese text readable in the printed JSON.
    print(json.dumps(reply.json(), ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
# Allow importing this module without firing the HTTP request.
if __name__ == "__main__":
    main()
|
||||
|
||||
129
transformers_server.py
Normal file
129
transformers_server.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import logger

# NOTE: rebinds the name `logger` from the module to a Logger instance;
# the logger module object is no longer reachable after this line.
logger = logger.get_logger(__file__)


app = FastAPI()

# Mutable module-level state, populated once by the startup hook below.
initialized = False  # flips to True after the model is fully loaded
tokenizer = None     # transformers tokenizer instance
model = None         # transformers model moved onto `device`
device = None        # torch.device("cuda") once CUDA is confirmed available
|
||||
|
||||
|
||||
def load_config():
    """Parse the CONFIG_JSON environment variable.

    Returns an empty dict when the variable is unset or blank.
    Raises ValueError when the value is not valid JSON or is valid JSON
    but not an object.
    """
    raw = os.environ.get("CONFIG_JSON", "").strip()
    if raw:
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise ValueError("CONFIG_JSON is not valid JSON") from exc

        if not isinstance(parsed, dict):
            raise ValueError("CONFIG_JSON must be a JSON object")
        return parsed

    return {}
|
||||
|
||||
|
||||
def resolve_transformers_class(class_name, default_name):
    """Resolve a class by name on the transformers module.

    Falls back to *default_name* when *class_name* is falsy. Returns a
    (name, class) pair; raises ValueError for names transformers does not export.
    """
    resolved_name = class_name if class_name else default_name
    resolved_class = getattr(transformers, resolved_name, None)
    if resolved_class is not None:
        return resolved_name, resolved_class
    raise ValueError(f"Unsupported transformers class: {resolved_name}")
|
||||
|
||||
|
||||
def resolve_torch_dtype(dtype_name, default_name="float16"):
    """Map a dtype string to a (name, torch dtype) pair.

    Falls back to *default_name* when *dtype_name* is falsy. The literal
    "auto" passes through unchanged (from_pretrained interprets it itself).
    Raises ValueError for names torch does not define.
    """
    resolved_name = dtype_name if dtype_name else default_name
    if resolved_name == "auto":
        return resolved_name, "auto"

    resolved_dtype = getattr(torch, resolved_name, None)
    if resolved_dtype is not None:
        return resolved_name, resolved_dtype
    raise ValueError(f"Unsupported torch dtype: {resolved_name}")
|
||||
|
||||
|
||||
class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # The raw text to classify.
    text: str
|
||||
|
||||
@app.get("/")
def read_root():
    """Liveness check: static greeting, no model access."""
    return {"message": "Hello, World!"}
|
||||
|
||||
# NOTE(review): @app.on_event("startup") is deprecated in newer FastAPI in
# favour of lifespan handlers — confirm against the FastAPI version shipped
# in the pinned base image before migrating.
@app.on_event("startup")
def load_model():
    """FastAPI startup hook: load tokenizer and model from /model onto the GPU.

    Populates the module-level `tokenizer`, `model` and `device`, then flips
    `initialized` to True. Raises RuntimeError when no CUDA device is visible
    and ValueError when CONFIG_JSON names an unknown class or dtype.
    """
    logger.info("loading model")
    global initialized, tokenizer, model, device

    # Fail fast: this service is GPU-only.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but is not available")

    device = torch.device("cuda")

    # Fixed container mount point for the model directory (see docker.sh -v flag).
    model_path = "/model"
    config = load_config()

    # CONFIG_JSON may override which transformers classes and dtype are used.
    tokenizer_class_name, tokenizer_class = resolve_transformers_class(
        config.get("tokenizer_class"),
        "AutoTokenizer",
    )
    model_class_name, model_class = resolve_transformers_class(
        config.get("model_class"),
        "AutoModelForSequenceClassification",
    )
    torch_dtype_name, torch_dtype = resolve_torch_dtype(config.get("torch_dtype"))

    logger.info(
        "resolved config: "
        f"model_class={model_class_name}, "
        f"tokenizer_class={tokenizer_class_name}, "
        f"torch_dtype={torch_dtype_name}"
    )

    tokenizer = tokenizer_class.from_pretrained(model_path)
    model = model_class.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
    )
    model = model.to(device)
    # eval() disables dropout and other training-only behavior for inference.
    model.eval()

    initialized = True
    logger.info("model loaded successfully")
|
||||
|
||||
@app.get("/v1/models")
async def get_status():
    """Readiness probe: returns the bare boolean `initialized` as the JSON body."""
    logger.info(f"get status, initialized={initialized}")
    return initialized
|
||||
|
||||
@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify one text and return {"label": <predicted label>}.

    NOTE(review): assumes the startup hook has completed; a request arriving
    before then would hit `tokenizer`/`model` while they are still None.
    """
    text = request.text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Move every input tensor onto the same device as the model.
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)

    # argmax over the last logits dimension yields the predicted class index.
    predicted_class_id = outputs.logits.argmax(dim=-1).item()
    id2label = model.config.id2label
    # Fall back to the numeric id string when the config defines no label.
    predicted_label = id2label.get(predicted_class_id, str(predicted_class_id))

    logger.info(f"text: {text}")
    logger.info(f"predicted_class_id: {predicted_class_id}")
    logger.info(f"predicted_label: {predicted_label}")

    return {"label": predicted_label}
|
||||
|
||||
|
||||
# Local/dev entry point; the Docker image also runs this file directly.
if __name__ == '__main__':
    uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)
|
||||
|
||||
Reference in New Issue
Block a user