init project files
This commit is contained in:
12
Dockerfile
Normal file
12
Dockerfile
Normal file
@@ -0,0 +1,12 @@
|
||||
# Serving image for image classification on the Iluvatar BI150 vLLM base.
FROM git.modelhub.org.cn:9443/enginex-iluvatar-bi150/vllm:0.8.3

WORKDIR /workspace

# Install dependencies first so this layer is cached across code-only changes.
COPY requirements.txt /workspace
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

ADD . /workspace

# uvicorn binds 0.0.0.0:8000 (see transformers_server.py) and docker.sh
# publishes 8000 — the old `EXPOSE 80` documented a port nothing listens on.
EXPOSE 8000

# Exec form without an sh wrapper so SIGTERM reaches the server process
# directly and the container stops cleanly.
CMD ["python3", "transformers_server.py"]
|
||||
|
||||
BIN
cats_image.jpeg
Normal file
BIN
cats_image.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 169 KiB |
6
docker.sh
Executable file
6
docker.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
# Rebuild and (re)launch the BI150 image-classification container.

# Remove any previous instance first. Both commands exit non-zero when the
# container does not exist, but the script has no `set -e`, so it continues.
docker stop bi150_ic
docker rm bi150_ic

# Build the serving image from the local Dockerfile.
docker build . -t bi150_image_classification

# Run the server: publish uvicorn's port 8000 as host port 17777, mount the
# resnet-50 checkpoint read-only at /model (the path transformers_server.py
# loads from), pass runtime configuration via CONFIG_JSON, and hand the
# Iluvatar accelerator device through to the container.
# NOTE(review): CONFIG_JSON uses the key "processer", while the server's
# defaults use "processer_class" — verify the server honors this spelling.
docker run -p 17777:8000 -v /mnt/contest_ceph/aiyueqi/image_classification/microsoft/resnet-50/:/model:ro -it --device=/dev/iluvatar0:/dev/iluvatar0 --name bi150_ic -e CONFIG_JSON='{"model_class": "AutoModelForImageClassification", "processer": "AutoImageProcessor", "torch_dtype": "auto"}' bi150_image_classification
|
||||
13
logger.py
Normal file
13
logger.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
import logging
import os

# Configure the root logger once at import time; every module that imports
# this file shares the same handler and format. Verbosity is controlled by
# the LOGLEVEL environment variable (default: INFO).
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
|
||||
|
||||
def get_logger(file):
    """Return the logger registered under *file* (callers pass ``__file__``)."""
    named_logger = logging.getLogger(file)
    return named_logger
|
||||
|
||||
2
requirements.txt
Normal file
2
requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
128
transformers_server.py
Normal file
128
transformers_server.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import base64
import json
import os
from io import BytesIO

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image
import torch
import transformers
from transformers import AutoImageProcessor, AutoModelForImageClassification
import uvicorn

import logger
# NOTE(review): this rebinding shadows the imported `logger` module with a
# Logger instance — later code can no longer reach logger.get_logger. It
# works here, but is easy to trip over when extending the file.
logger = logger.get_logger(__file__)
|
||||
|
||||
app = FastAPI()

# Module-level serving state, populated by load_model() at startup.
# /v1/classify rejects requests with 503 until `initialized` is True.
initialized = False
processor = None
model = None
device = None
|
||||
|
||||
class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # base64-encoded image bytes
    image: str
    # optional caller-supplied name, used only for logging
    image_name: str | None = None
|
||||
|
||||
def load_config():
    """Build the model configuration from the CONFIG_JSON environment variable.

    Returns a dict with keys ``model_class``, ``processer_class`` and
    ``torch_dtype``; user-supplied keys override the defaults.

    Raises:
        ValueError: if CONFIG_JSON is set but is not a JSON object.
    """
    default_config = {
        "model_class": "AutoModelForImageClassification",
        "processer_class": "AutoImageProcessor",
        "torch_dtype": "auto",
    }

    raw_config = os.getenv("CONFIG_JSON")
    if not raw_config:
        return default_config

    try:
        user_config = json.loads(raw_config)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc

    if not isinstance(user_config, dict):
        raise ValueError("CONFIG_JSON must decode to an object")

    default_config.update(user_config)

    # The canonical key is the (historically misspelled) "processer_class".
    # Also accept common spelling variants — docker.sh passes {"processer":
    # ...}, which the previous code silently ignored, falling back to the
    # default processor class by coincidence rather than by configuration.
    if "processer_class" not in user_config:
        for alias in ("processor_class", "processer", "processor"):
            if alias in user_config:
                default_config["processer_class"] = user_config[alias]
                break

    return default_config
|
||||
|
||||
|
||||
def resolve_transformers_class(class_name):
    """Resolve *class_name* to the attribute of the same name on the
    ``transformers`` package; raise ValueError for unknown names."""
    try:
        resolved = getattr(transformers, class_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported transformers class: {class_name}") from exc
    return resolved
|
||||
|
||||
|
||||
def resolve_torch_dtype(dtype_name):
    """Map a configuration string to a ``torch.dtype``.

    "auto" passes through unchanged so ``from_pretrained`` can pick the
    checkpoint's native dtype.

    Raises:
        ValueError: if *dtype_name* does not name a torch dtype.
    """
    if dtype_name == "auto":
        return "auto"

    dtype = getattr(torch, dtype_name, None)
    # Reject torch attributes that are not dtypes ("nn", "cat", ...) up
    # front, instead of returning them and failing much later inside model
    # loading with a confusing error.
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"unsupported torch dtype: {dtype_name}")
    return dtype
|
||||
|
||||
@app.get("/")
def read_root():
    """Trivial liveness endpoint."""
    greeting = {"message": "Hello, World!"}
    return greeting
|
||||
|
||||
@app.on_event("startup")
def load_model():
    """FastAPI startup hook: load the processor and model from /model onto
    the best available device, then flip the module-level readiness flag."""
    logger.info("loading model")
    global initialized, processor, model, device

    # Resolve configuration strings into actual classes/dtypes before
    # touching the checkpoint, so a bad config fails fast.
    config = load_config()
    processor_class = resolve_transformers_class(config["processer_class"])
    model_class = resolve_transformers_class(config["model_class"])
    torch_dtype = resolve_torch_dtype(config["torch_dtype"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logger.info(
        f"model config: model_class={config['model_class']}, "
        f"processer_class={config['processer_class']}, torch_dtype={config['torch_dtype']}, "
        f"device={device}"
    )

    # /model is the read-only mount point supplied by docker.sh.
    processor = processor_class.from_pretrained("/model")
    model = model_class.from_pretrained("/model", torch_dtype=torch_dtype)
    model.to(device)
    model.eval()  # inference only — disables dropout etc.

    # Only now may /v1/classify accept traffic (it checks this flag).
    initialized = True
    logger.info("model loaded successfully")
|
||||
|
||||
@app.get("/v1/models")
async def get_status():
    """Readiness probe: True once the startup hook has finished loading."""
    ready = initialized
    logger.info(f"get status, initialized={ready}")
    return ready
|
||||
|
||||
@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify a base64-encoded image; return the top-5 class indices.

    Error mapping:
        503 — model has not finished loading yet,
        400 — payload is not base64 / not a decodable image,
        500 — any failure during preprocessing or inference.
    """
    if not initialized or processor is None or model is None or device is None:
        raise HTTPException(status_code=503, detail="model is not initialized")

    # Decode the client payload; convert to RGB so grayscale/RGBA inputs
    # match the 3-channel input the image processor produces.
    try:
        image_bytes = base64.b64decode(request.image)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        logger.exception("failed to decode input image")
        raise HTTPException(status_code=400, detail="invalid image payload") from exc

    try:
        inputs = processor(images=image, return_tensors="pt")
        # Move every tensor to the model's device before the forward pass.
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        # NOTE(review): these are raw class indices, not human-readable label
        # strings — model.config.id2label is never consulted. Confirm the API
        # consumer expects indices under the "labels" key.
        top5 = torch.topk(logits, k=5, dim=-1).indices[0].tolist()
        logger.info(f"classify image_name={request.image_name}, labels={top5}")
        return {"labels": top5}
    except Exception as exc:
        logger.exception("image classification failed")
        raise HTTPException(status_code=500, detail="classification failed") from exc
|
||||
|
||||
if __name__ == "__main__":
    # Serve the ASGI app directly when run as a script (the Docker CMD path).
    uvicorn.run(
        "transformers_server:app",
        host="0.0.0.0",
        port=8000,
        workers=1,
        access_log=False,
    )
|
||||
Reference in New Issue
Block a user