diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6d86dd6 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +FROM git.modelhub.org.cn:9443/enginex-iluvatar-bi150/vllm:0.8.3 + +WORKDIR /workspace + +COPY requirements.txt /workspace +RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + +ADD . /workspace + +EXPOSE 8000 +CMD ["sh","-c","python3 transformers_server.py"] + diff --git a/cats_image.jpeg b/cats_image.jpeg new file mode 100644 index 0000000..e131e8e Binary files /dev/null and b/cats_image.jpeg differ diff --git a/docker.sh b/docker.sh new file mode 100755 index 0000000..8e8c70c --- /dev/null +++ b/docker.sh @@ -0,0 +1,6 @@ +docker stop bi150_ic +docker rm bi150_ic + +docker build . -t bi150_image_classification + +docker run -p 17777:8000 -v /mnt/contest_ceph/aiyueqi/image_classification/microsoft/resnet-50/:/model:ro -it --device=/dev/iluvatar0:/dev/iluvatar0 --name bi150_ic -e CONFIG_JSON='{"model_class": "AutoModelForImageClassification", "processer_class": "AutoImageProcessor", "torch_dtype": "auto"}' bi150_image_classification diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..d238e70 --- /dev/null +++ b/logger.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +import logging +import os + +logging.basicConfig( + format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO"), +) + +def get_logger(file): + return logging.getLogger(file) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..97dc7cd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn diff --git a/transformers_server.py b/transformers_server.py new file mode 100644 index 0000000..ebdb7af --- /dev/null +++ b/transformers_server.py @@ -0,0 +1,128 @@ +import base64 +import json +import os +from io import BytesIO + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from PIL 
import Image +import torch +import transformers +from transformers import AutoImageProcessor, AutoModelForImageClassification +import uvicorn + +import logger +logger = logger.get_logger(__file__) + +app = FastAPI() +initialized = False +processor = None +model = None +device = None + +class ClassifyRequest(BaseModel): + image: str + image_name: str | None = None + +def load_config(): + default_config = { + "model_class": "AutoModelForImageClassification", + "processer_class": "AutoImageProcessor", + "torch_dtype": "auto", + } + + raw_config = os.getenv("CONFIG_JSON") + if not raw_config: + return default_config + + try: + user_config = json.loads(raw_config) + except json.JSONDecodeError as exc: + raise ValueError("CONFIG_JSON is not valid JSON") from exc + + if not isinstance(user_config, dict): + raise ValueError("CONFIG_JSON must decode to an object") + + default_config.update(user_config) + if "processor_class" in user_config and "processer_class" not in user_config: + default_config["processer_class"] = user_config["processor_class"] + + return default_config + + +def resolve_transformers_class(class_name): + try: + return getattr(transformers, class_name) + except AttributeError as exc: + raise ValueError(f"unsupported transformers class: {class_name}") from exc + + +def resolve_torch_dtype(dtype_name): + if dtype_name == "auto": + return "auto" + + try: + return getattr(torch, dtype_name) + except AttributeError as exc: + raise ValueError(f"unsupported torch dtype: {dtype_name}") from exc + +@app.get("/") +def read_root(): + return {"message": "Hello, World!"} + +@app.on_event("startup") +def load_model(): + logger.info("loading model") + global initialized, processor, model, device + + config = load_config() + processor_class = resolve_transformers_class(config["processer_class"]) + model_class = resolve_transformers_class(config["model_class"]) + torch_dtype = resolve_torch_dtype(config["torch_dtype"]) + device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + + logger.info( + f"model config: model_class={config['model_class']}, " + f"processer_class={config['processer_class']}, torch_dtype={config['torch_dtype']}, " + f"device={device}" + ) + + processor = processor_class.from_pretrained("/model") + model = model_class.from_pretrained("/model", torch_dtype=torch_dtype) + model.to(device) + model.eval() + + initialized = True + logger.info("model loaded successfully") + +@app.get("/v1/models") +async def get_status(): + logger.info(f"get status, initialized={initialized}") + return initialized + +@app.post("/v1/classify") +async def classify(request: ClassifyRequest): + if not initialized or processor is None or model is None or device is None: + raise HTTPException(status_code=503, detail="model is not initialized") + + try: + image_bytes = base64.b64decode(request.image) + image = Image.open(BytesIO(image_bytes)).convert("RGB") + except Exception as exc: + logger.exception("failed to decode input image") + raise HTTPException(status_code=400, detail="invalid image payload") from exc + + try: + inputs = processor(images=image, return_tensors="pt") + inputs = {key: value.to(device) for key, value in inputs.items()} + with torch.no_grad(): + logits = model(**inputs).logits + top5 = torch.topk(logits, k=5, dim=-1).indices[0].tolist() + logger.info(f"classify image_name={request.image_name}, labels={top5}") + return {"labels": top5} + except Exception as exc: + logger.exception("image classification failed") + raise HTTPException(status_code=500, detail="classification failed") from exc + +if __name__ == '__main__': + uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)