Files
enginex-bi_series-image-cla…/transformers_server.py

129 lines
4.0 KiB
Python
Raw Normal View History

2026-04-08 06:16:35 +00:00
import base64
import json
import os
from io import BytesIO
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image
import torch
import transformers
from transformers import AutoImageProcessor, AutoModelForImageClassification
import uvicorn
# Project-local logging helper; the module name is immediately rebound to the
# logger instance it returns (intentional shadowing of the `logger` module).
import logger
logger = logger.get_logger(__file__)

app = FastAPI()

# Module-level inference state, populated once by the startup hook below.
initialized = False  # flips to True only after processor + model are loaded
processor = None     # image pre-processor, loaded from /model at startup
model = None         # classification model, loaded from /model at startup
device = None        # torch.device — cuda when available, else cpu
class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # Base64-encoded image bytes (any format PIL can open).
    image: str
    # Optional caller-supplied name; used only for logging, never for lookup.
    image_name: str | None = None
def load_config():
    """Build the server configuration: defaults overlaid with CONFIG_JSON.

    The CONFIG_JSON environment variable, when set, must contain a JSON
    object whose keys override the defaults. The historical misspelled
    "processer_class" key is the one the server reads; a correctly spelled
    "processor_class" override is mapped onto it when the misspelled key
    is absent, so either spelling works for callers.

    Returns:
        dict: merged configuration.

    Raises:
        ValueError: CONFIG_JSON is set but is not valid JSON, or does not
            decode to a JSON object.
    """
    config = {
        "model_class": "AutoModelForImageClassification",
        "processer_class": "AutoImageProcessor",
        "torch_dtype": "auto",
    }
    raw = os.getenv("CONFIG_JSON")
    if not raw:
        # Unset or empty string both mean "use defaults".
        return config
    try:
        overrides = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc
    if not isinstance(overrides, dict):
        raise ValueError("CONFIG_JSON must decode to an object")
    config = {**config, **overrides}
    # Accept the correct spelling as an alias for the legacy key.
    if "processor_class" in overrides and "processer_class" not in overrides:
        config["processer_class"] = overrides["processor_class"]
    return config
def resolve_transformers_class(class_name):
    """Resolve `class_name` to the attribute of that name on the
    transformers package (e.g. "AutoModelForImageClassification").

    Raises:
        ValueError: the transformers package exposes no such attribute.
    """
    try:
        resolved = getattr(transformers, class_name)
    except AttributeError as exc:
        message = f"unsupported transformers class: {class_name}"
        raise ValueError(message) from exc
    return resolved
def resolve_torch_dtype(dtype_name):
    """Translate a config dtype string into a value accepted by
    ``from_pretrained(torch_dtype=...)``.

    "auto" is passed through unchanged; any other name must resolve to a
    ``torch.dtype`` attribute of the torch module (e.g. "float16",
    "bfloat16", "float32").

    Raises:
        ValueError: the name is not a torch attribute, or resolves to a
            torch attribute that is not a dtype (e.g. "Tensor", "nn") —
            previously such names slipped through and only failed later,
            deep inside model loading, with a confusing error.
    """
    if dtype_name == "auto":
        return "auto"
    try:
        dtype = getattr(torch, dtype_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported torch dtype: {dtype_name}") from exc
    # getattr happily returns any torch attribute (classes, submodules,
    # functions); only genuine dtype objects are valid here.
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"unsupported torch dtype: {dtype_name}")
    return dtype
@app.get("/")
def read_root():
    """Trivial liveness endpoint: always responds with a greeting."""
    greeting = {"message": "Hello, World!"}
    return greeting
@app.on_event("startup")
def load_model():
    """Startup hook: load the processor and model from /model onto the best
    available device, then flip the module-level `initialized` flag.

    NOTE(review): `@app.on_event` is deprecated in recent FastAPI versions
    in favour of lifespan handlers — confirm the pinned FastAPI version
    before migrating.
    """
    logger.info("loading model")
    global initialized, processor, model, device
    config = load_config()
    # Resolve configured class names and the dtype string into real objects;
    # each resolver raises ValueError on an unknown name, aborting startup.
    processor_class = resolve_transformers_class(config["processer_class"])
    model_class = resolve_transformers_class(config["model_class"])
    torch_dtype = resolve_torch_dtype(config["torch_dtype"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(
        f"model config: model_class={config['model_class']}, "
        f"processer_class={config['processer_class']}, torch_dtype={config['torch_dtype']}, "
        f"device={device}"
    )
    # Artifacts are read from /model — presumably a container mount; verify
    # against the deployment manifest.
    processor = processor_class.from_pretrained("/model")
    model = model_class.from_pretrained("/model", torch_dtype=torch_dtype)
    model.to(device)
    model.eval()  # inference mode: disables dropout / batch-norm updates
    # Only now may /v1/classify accept requests.
    initialized = True
    logger.info("model loaded successfully")
@app.get("/v1/models")
async def get_status():
    """Readiness probe: returns the bare `initialized` flag (JSON true/false)."""
    logger.info(f"get status, initialized={initialized}")
    return initialized
@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify a base64-encoded image; return the indices of the top-5 logits.

    Returns:
        {"labels": [int, int, int, int, int]} — raw class indices (not
        human-readable label strings).

    Raises:
        HTTPException 503: startup has not finished loading the model.
        HTTPException 400: payload cannot be decoded as an image.
        HTTPException 500: any failure during preprocessing or inference.
    """
    if not initialized or processor is None or model is None or device is None:
        raise HTTPException(status_code=503, detail="model is not initialized")
    try:
        # base64 -> bytes -> PIL image; force RGB so grayscale/RGBA inputs
        # match the 3-channel input the processor produces tensors for.
        image_bytes = base64.b64decode(request.image)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        logger.exception("failed to decode input image")
        raise HTTPException(status_code=400, detail="invalid image payload") from exc
    try:
        inputs = processor(images=image, return_tensors="pt")
        # Move every tensor in the processor output onto the model's device.
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        # Top-5 class indices for the first (and only) item in the batch.
        top5 = torch.topk(logits, k=5, dim=-1).indices[0].tolist()
        logger.info(f"classify image_name={request.image_name}, labels={top5}")
        return {"labels": top5}
    except Exception as exc:
        logger.exception("image classification failed")
        raise HTTPException(status_code=500, detail="classification failed") from exc
if __name__ == '__main__':
    # workers=1: the model/processor globals live per-process, so a single
    # worker keeps exactly one model copy. uvicorn access logs are disabled;
    # the app does its own request logging.
    uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)