init project files
This commit is contained in:
12
Dockerfile
Normal file
12
Dockerfile
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Serving image for the image-classification service, built on the
# Iluvatar BI150 vLLM base image (provides torch / transformers / PIL).
FROM git.modelhub.org.cn:9443/enginex-iluvatar-bi150/vllm:0.8.3

WORKDIR /workspace

# Copy and install the thin API dependencies first so this layer is
# cached independently of application-source changes.
COPY requirements.txt /workspace

RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

ADD . /workspace

# uvicorn listens on 8000 (see transformers_server.py and docker.sh's
# "-p 17777:8000"); the original "EXPOSE 80" documented a port nothing
# listens on.
EXPOSE 8000

CMD ["sh","-c","python3 transformers_server.py"]
|
||||||
|
|
||||||
BIN
cats_image.jpeg
Normal file
BIN
cats_image.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 169 KiB |
6
docker.sh
Executable file
6
docker.sh
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/sh
# Rebuild and (re)start the BI150 image-classification container.

# Remove any previous instance; "|| true" keeps the script going when no
# container named bi150_ic exists yet.
docker stop bi150_ic || true
docker rm bi150_ic || true

docker build . -t bi150_image_classification

# Host port 17777 -> container port 8000 (uvicorn, see transformers_server.py).
# The model directory is mounted read-only at /model, where the server loads from.
# NOTE: the config key must be "processer_class" -- that is what the server's
# load_config reads; the original "processer" key was silently ignored.
docker run -p 17777:8000 -v /mnt/contest_ceph/aiyueqi/image_classification/microsoft/resnet-50/:/model:ro -it --device=/dev/iluvatar0:/dev/iluvatar0 --name bi150_ic -e CONFIG_JSON='{"model_class": "AutoModelForImageClassification", "processer_class": "AutoImageProcessor", "torch_dtype": "auto"}' bi150_image_classification
|
||||||
13
logger.py
Normal file
13
logger.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# -*- coding: utf-8 -*-
"""Shared logging setup: configures the root logger once at import time."""

import logging

import os

# Root-logger configuration applied when this module is first imported.
# The level is overridable via the LOGLEVEL environment variable
# (defaults to INFO).
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
|
||||||
|
|
||||||
|
def get_logger(file):
    """Return the logger registered under *file*.

    Callers pass ``__file__`` so records carry the emitting module's
    source path as the logger name.
    """
    named_logger = logging.getLogger(file)
    return named_logger
|
||||||
|
|
||||||
2
requirements.txt
Normal file
2
requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
128
transformers_server.py
Normal file
128
transformers_server.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
import logger

# NOTE(review): this rebinding shadows the imported ``logger`` module with
# a logging.Logger instance; all later ``logger.<attr>`` calls in this file
# hit the Logger object, not the module.
logger = logger.get_logger(__file__)

app = FastAPI()

# Mutable module-level state, populated by load_model() at startup.
initialized = False  # flips to True once the model is ready to serve
processor = None  # image pre-processor loaded from /model
model = None  # classification model loaded from /model
device = None  # torch.device the model was moved to
|
||||||
|
|
||||||
|
class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # Base64-encoded image bytes (any format PIL can open).
    image: str
    # Optional client-side identifier; used only in the classify log line.
    image_name: str | None = None
|
||||||
|
|
||||||
|
def load_config():
    """Build the serving configuration from the CONFIG_JSON env variable.

    Returns:
        dict with keys ``model_class``, ``processer_class`` and
        ``torch_dtype``; user-supplied keys override the defaults.

    Raises:
        ValueError: if CONFIG_JSON is set but is not valid JSON, or does
            not decode to a JSON object.
    """
    default_config = {
        "model_class": "AutoModelForImageClassification",
        "processer_class": "AutoImageProcessor",
        "torch_dtype": "auto",
    }

    raw_config = os.getenv("CONFIG_JSON")
    if not raw_config:
        return default_config

    try:
        user_config = json.loads(raw_config)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc

    if not isinstance(user_config, dict):
        raise ValueError("CONFIG_JSON must decode to an object")

    default_config.update(user_config)

    # Accept common spellings of the processor key ("processor_class",
    # "processer", "processor") so configs such as docker.sh's
    # {"processer": ...} take effect instead of being silently ignored.
    if "processer_class" not in user_config:
        for alias in ("processor_class", "processer", "processor"):
            if alias in user_config:
                default_config["processer_class"] = user_config[alias]
                break

    return default_config
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_transformers_class(class_name):
    """Look up *class_name* on the transformers package.

    Raises:
        ValueError: when transformers exposes no attribute of that name.
    """
    try:
        resolved = getattr(transformers, class_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported transformers class: {class_name}") from exc
    return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_torch_dtype(dtype_name):
    """Translate a config string into a torch dtype.

    "auto" is passed through unchanged (from_pretrained interprets it);
    any other value must name a real torch.dtype, e.g. "float16".

    Raises:
        ValueError: for unknown names, or torch attributes that exist but
            are not dtypes (getattr would happily return modules or
            functions such as "nn" or "add").
    """
    if dtype_name == "auto":
        return "auto"

    try:
        dtype = getattr(torch, dtype_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported torch dtype: {dtype_name}") from exc

    # Only accept genuine dtype objects; anything else would fail much
    # later, deep inside from_pretrained, with a confusing error.
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"unsupported torch dtype: {dtype_name}")
    return dtype
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def read_root():
|
||||||
|
return {"message": "Hello, World!"}
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def load_model():
|
||||||
|
logger.info("loading model")
|
||||||
|
global initialized, processor, model, device
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
processor_class = resolve_transformers_class(config["processer_class"])
|
||||||
|
model_class = resolve_transformers_class(config["model_class"])
|
||||||
|
torch_dtype = resolve_torch_dtype(config["torch_dtype"])
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"model config: model_class={config['model_class']}, "
|
||||||
|
f"processer_class={config['processer_class']}, torch_dtype={config['torch_dtype']}, "
|
||||||
|
f"device={device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
processor = processor_class.from_pretrained("/model")
|
||||||
|
model = model_class.from_pretrained("/model", torch_dtype=torch_dtype)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
initialized = True
|
||||||
|
logger.info("model loaded successfully")
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
async def get_status():
|
||||||
|
logger.info(f"get status, initialized={initialized}")
|
||||||
|
return initialized
|
||||||
|
|
||||||
|
@app.post("/v1/classify")
|
||||||
|
async def classify(request: ClassifyRequest):
|
||||||
|
if not initialized or processor is None or model is None or device is None:
|
||||||
|
raise HTTPException(status_code=503, detail="model is not initialized")
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_bytes = base64.b64decode(request.image)
|
||||||
|
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("failed to decode input image")
|
||||||
|
raise HTTPException(status_code=400, detail="invalid image payload") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
inputs = processor(images=image, return_tensors="pt")
|
||||||
|
inputs = {key: value.to(device) for key, value in inputs.items()}
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(**inputs).logits
|
||||||
|
top5 = torch.topk(logits, k=5, dim=-1).indices[0].tolist()
|
||||||
|
logger.info(f"classify image_name={request.image_name}, labels={top5}")
|
||||||
|
return {"labels": top5}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("image classification failed")
|
||||||
|
raise HTTPException(status_code=500, detail="classification failed") from exc
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Single-worker uvicorn on all interfaces; port 8000 matches the
    # container mapping in docker.sh (-p 17777:8000).
    uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)
|
||||||
Reference in New Issue
Block a user