From 5a21e8d6a29bef0a035116569ae6776d10e7fe7c Mon Sep 17 00:00:00 2001 From: aiyueqi Date: Fri, 10 Apr 2026 10:05:19 +0000 Subject: [PATCH] init project --- Dockerfile | 13 +++++ README.md | 19 ++++++ docker.sh | 6 ++ logger.py | 13 +++++ requirements.txt | 2 + test.py | 25 ++++++++ transformers_server.py | 129 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 207 insertions(+) create mode 100644 Dockerfile create mode 100755 docker.sh create mode 100644 logger.py create mode 100644 requirements.txt create mode 100644 test.py create mode 100644 transformers_server.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..165b624 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM git.modelhub.org.cn:9443/enginex-iluvatar-bi150/vllm:0.8.3 + +WORKDIR /workspace + +COPY requirements.txt /workspace +RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + +ADD . /workspace + +EXPOSE 8000 +CMD ["sh","-c","python3 transformers_server.py"] + + diff --git a/README.md b/README.md index 5b851bf..399199e 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,21 @@ # enginex-bi_150-text-classification +## Quickstart +### 启动服务 +修改docker.sh脚本中的$mountpath为本地的模型挂载路径 +然后运行./docker.sh +当打印出以下内容时表示模型加载(load)成功 +``` +2026-04-10 10:01:00 /workspace/transformers_server.py INFO model loaded successfully +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` + +### 运行测试 +在容器中执行 python3 test.py +服务端日志中会打印出以下内容, 即文本分类结果的id和label: +``` +text: 这家餐厅的服务特别周到,菜也很好吃,我下次还会再来。 +predicted_class_id: 4 +predicted_label: 5 stars +``` diff --git a/docker.sh b/docker.sh new file mode 100755 index 0000000..3567ff2 --- /dev/null +++ b/docker.sh @@ -0,0 +1,6 @@ +docker stop bi150_tc +docker rm bi150_tc + +docker build . 
-t bi150_text_classification + +docker run --rm -p 17777:8000 -v $mountpath:/model:ro -it --device=/dev/iluvatar0:/dev/iluvatar0 --name bi150_tc -e CONFIG_JSON='{"model_class": "AutoModelForSequenceClassification", "tokenizer_class": "AutoTokenizer", "torch_dtype": "auto"}' bi150_text_classification diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..d238e70 --- /dev/null +++ b/logger.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +import logging +import os + +logging.basicConfig( + format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO"), +) + +def get_logger(file): + return logging.getLogger(file) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..97dc7cd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn diff --git a/test.py b/test.py new file mode 100644 index 0000000..ea1bd4f --- /dev/null +++ b/test.py @@ -0,0 +1,25 @@ +import argparse +import json + +import requests + + +def main() -> None: + parser = argparse.ArgumentParser(description="Send a test request to /v1/classify") + parser.add_argument("--url", default="http://127.0.0.1:8000/v1/classify") + parser.add_argument("--text", default="这家餐厅的服务特别周到,菜也很好吃,我下次还会再来。") + parser.add_argument("--timeout", type=int, default=60) + args = parser.parse_args() + + response = requests.post( + args.url, + json={"text": args.text}, + timeout=args.timeout, + ) + response.raise_for_status() + print(json.dumps(response.json(), ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() + diff --git a/transformers_server.py b/transformers_server.py new file mode 100644 index 0000000..13dcc98 --- /dev/null +++ b/transformers_server.py @@ -0,0 +1,129 @@ +import json +import os + +from fastapi import FastAPI +from pydantic import BaseModel +import uvicorn + +import torch +import transformers + +import logger +logger = logger.get_logger(__file__) + +app = FastAPI() 
# Module-level state populated by the startup hook below.
initialized = False
tokenizer = None
model = None
device = None


def load_config():
    """Parse the CONFIG_JSON environment variable into a dict.

    Returns an empty dict when CONFIG_JSON is unset or blank.
    Raises ValueError when the value is not valid JSON or not a JSON object.
    """
    raw_config = os.environ.get("CONFIG_JSON", "").strip()
    if not raw_config:
        return {}

    try:
        config = json.loads(raw_config)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc

    if not isinstance(config, dict):
        raise ValueError("CONFIG_JSON must be a JSON object")

    return config


def resolve_transformers_class(class_name, default_name):
    """Resolve a class by name on the transformers package.

    Falls back to default_name when class_name is falsy.
    Returns (resolved_name, resolved_class); raises ValueError when the
    name does not exist on the transformers module.
    """
    resolved_name = class_name or default_name
    resolved_class = getattr(transformers, resolved_name, None)
    if resolved_class is None:
        raise ValueError(f"Unsupported transformers class: {resolved_name}")
    return resolved_name, resolved_class


def resolve_torch_dtype(dtype_name, default_name="float16"):
    """Map a dtype name ("float16", "auto", ...) to a torch dtype.

    "auto" is passed through unchanged so from_pretrained can pick the
    checkpoint's own dtype. Returns (resolved_name, resolved_dtype);
    raises ValueError for names that are not torch dtypes.
    """
    resolved_name = dtype_name or default_name
    if resolved_name == "auto":
        return resolved_name, "auto"
    resolved_dtype = getattr(torch, resolved_name, None)
    # getattr(torch, ...) also resolves names that are not dtypes (e.g.
    # "tensor"); require an actual torch.dtype so a bad CONFIG_JSON fails
    # fast here instead of deep inside from_pretrained.
    if not isinstance(resolved_dtype, torch.dtype):
        raise ValueError(f"Unsupported torch dtype: {resolved_name}")
    return resolved_name, resolved_dtype


class ClassifyRequest(BaseModel):
    # Request body for /v1/classify: the raw text to classify.
    text: str


@app.get("/")
def read_root():
    """Trivial liveness probe."""
    return {"message": "Hello, World!"}


@app.on_event("startup")
def load_model():
    """Load tokenizer and model from /model onto the GPU at startup.

    Model/tokenizer classes and torch dtype come from CONFIG_JSON (see
    load_config); defaults are AutoModelForSequenceClassification,
    AutoTokenizer and float16. Raises RuntimeError when no CUDA device
    is available.
    """
    logger.info("loading model")
    global initialized, tokenizer, model, device

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but is not available")

    device = torch.device("cuda")

    # Model weights are expected at the container mount point (see docker.sh).
    model_path = "/model"
    config = load_config()

    tokenizer_class_name, tokenizer_class = resolve_transformers_class(
        config.get("tokenizer_class"),
        "AutoTokenizer",
    )
    model_class_name, model_class = resolve_transformers_class(
        config.get("model_class"),
        "AutoModelForSequenceClassification",
    )
    torch_dtype_name, torch_dtype = resolve_torch_dtype(config.get("torch_dtype"))

    logger.info(
        "resolved config: "
        f"model_class={model_class_name}, "
        f"tokenizer_class={tokenizer_class_name}, "
        f"torch_dtype={torch_dtype_name}"
    )

    tokenizer = tokenizer_class.from_pretrained(model_path)
    model = model_class.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
    )
    model = model.to(device)
    model.eval()

    initialized = True
    logger.info("model loaded successfully")


@app.get("/v1/models")
async def get_status():
    """Report whether the model finished loading (bare boolean body)."""
    logger.info(f"get status, initialized={initialized}")
    return initialized


@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify request.text and return the predicted class id and label."""
    # Local import: the file-level fastapi import only pulls in FastAPI.
    from fastapi import HTTPException

    # Fail with an explicit 503 instead of an opaque 500 (AttributeError on
    # a None tokenizer/model) when a request arrives before startup finished.
    if not initialized:
        raise HTTPException(status_code=503, detail="model is not loaded yet")

    text = request.text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_class_id = outputs.logits.argmax(dim=-1).item()
    id2label = model.config.id2label
    predicted_label = id2label.get(predicted_class_id, str(predicted_class_id))

    logger.info(f"text: {text}")
    logger.info(f"predicted_class_id: {predicted_class_id}")
    logger.info(f"predicted_label: {predicted_label}")

    # Keep the original "label" key for existing callers; additionally expose
    # the id and label under the names the README documents.
    return {
        "label": predicted_label,
        "predicted_class_id": predicted_class_id,
        "predicted_label": predicted_label,
    }


if __name__ == '__main__':
    uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)