129 lines
4.0 KiB
Python
129 lines
4.0 KiB
Python
import base64
|
|
import json
|
|
import os
|
|
from io import BytesIO
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from PIL import Image
|
|
import torch
|
|
import transformers
|
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
import uvicorn
|
|
|
|
import logger

# Project-local logging helper.  NOTE(review): the module name `logger` is
# deliberately shadowed by the logger instance it returns; the module is not
# needed again after this line.
logger = logger.get_logger(__file__)
|
|
|
|
app = FastAPI()

# Global inference state, populated once by the startup hook (load_model).
# Endpoints must treat the service as unavailable until `initialized` is True.
initialized = False
processor = None  # image pre-processor loaded from /model
model = None  # classification model, moved to `device` and put in eval mode
device = None  # torch.device selected at startup (cuda when available)
|
|
|
|
class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # Base64-encoded image bytes; decoded and opened with PIL by the endpoint.
    image: str
    # Optional caller-supplied identifier, used only for logging.
    image_name: str | None = None
|
|
|
|
def load_config():
    """Build the runtime model configuration.

    Starts from built-in defaults and overlays any JSON object supplied in
    the CONFIG_JSON environment variable.  The legacy key spelling
    "processer_class" is the one the rest of this module reads; when a caller
    provides the conventional "processor_class" (and not the legacy key), its
    value is mirrored onto the legacy key.

    Returns:
        dict: the merged configuration.

    Raises:
        ValueError: if CONFIG_JSON is set but is not valid JSON, or decodes
            to something other than a JSON object.
    """
    config = {
        "model_class": "AutoModelForImageClassification",
        "processer_class": "AutoImageProcessor",
        "torch_dtype": "auto",
    }

    raw = os.getenv("CONFIG_JSON")
    if not raw:
        return config

    try:
        overrides = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc

    if not isinstance(overrides, dict):
        raise ValueError("CONFIG_JSON must decode to an object")

    config.update(overrides)
    if "processor_class" in overrides and "processer_class" not in overrides:
        config["processer_class"] = overrides["processor_class"]

    return config
|
|
|
|
|
|
def resolve_transformers_class(class_name):
    """Return the attribute named *class_name* from the transformers package.

    Raises:
        ValueError: if transformers exposes no attribute with that name.
    """
    try:
        resolved = getattr(transformers, class_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported transformers class: {class_name}") from exc
    return resolved
|
|
|
|
|
|
def resolve_torch_dtype(dtype_name):
    """Map a config string to the torch_dtype value for from_pretrained.

    Args:
        dtype_name: either the literal "auto" or the name of a torch dtype
            attribute such as "float16" or "bfloat16".

    Returns:
        The string "auto", or the corresponding torch.dtype object.

    Raises:
        ValueError: if the name is not a torch attribute, or resolves to a
            torch attribute that is not a dtype.
    """
    if dtype_name == "auto":
        return "auto"

    try:
        dtype = getattr(torch, dtype_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported torch dtype: {dtype_name}") from exc

    # getattr would happily return any torch attribute (e.g. "zeros" is a
    # function) — only genuine dtypes are valid for from_pretrained, so
    # reject everything else with the same error callers already expect.
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"unsupported torch dtype: {dtype_name}")

    return dtype
|
|
|
|
@app.get("/")
|
|
def read_root():
|
|
return {"message": "Hello, World!"}
|
|
|
|
@app.on_event("startup")
|
|
def load_model():
|
|
logger.info("loading model")
|
|
global initialized, processor, model, device
|
|
|
|
config = load_config()
|
|
processor_class = resolve_transformers_class(config["processer_class"])
|
|
model_class = resolve_transformers_class(config["model_class"])
|
|
torch_dtype = resolve_torch_dtype(config["torch_dtype"])
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
logger.info(
|
|
f"model config: model_class={config['model_class']}, "
|
|
f"processer_class={config['processer_class']}, torch_dtype={config['torch_dtype']}, "
|
|
f"device={device}"
|
|
)
|
|
|
|
processor = processor_class.from_pretrained("/model")
|
|
model = model_class.from_pretrained("/model", torch_dtype=torch_dtype)
|
|
model.to(device)
|
|
model.eval()
|
|
|
|
initialized = True
|
|
logger.info("model loaded successfully")
|
|
|
|
@app.get("/v1/models")
|
|
async def get_status():
|
|
logger.info(f"get status, initialized={initialized}")
|
|
return initialized
|
|
|
|
@app.post("/v1/classify")
|
|
async def classify(request: ClassifyRequest):
|
|
if not initialized or processor is None or model is None or device is None:
|
|
raise HTTPException(status_code=503, detail="model is not initialized")
|
|
|
|
try:
|
|
image_bytes = base64.b64decode(request.image)
|
|
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
|
except Exception as exc:
|
|
logger.exception("failed to decode input image")
|
|
raise HTTPException(status_code=400, detail="invalid image payload") from exc
|
|
|
|
try:
|
|
inputs = processor(images=image, return_tensors="pt")
|
|
inputs = {key: value.to(device) for key, value in inputs.items()}
|
|
with torch.no_grad():
|
|
logits = model(**inputs).logits
|
|
top5 = torch.topk(logits, k=5, dim=-1).indices[0].tolist()
|
|
logger.info(f"classify image_name={request.image_name}, labels={top5}")
|
|
return {"labels": top5}
|
|
except Exception as exc:
|
|
logger.exception("image classification failed")
|
|
raise HTTPException(status_code=500, detail="classification failed") from exc
|
|
|
|
if __name__ == '__main__':
    # Single worker on purpose: the model is loaded per-process in the
    # startup hook, so extra workers would multiply memory use.
    # NOTE(review): assumes this file is named transformers_server.py —
    # the import string must match the module name.
    uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)
|