"""FastAPI server that serves a transformers image-classification model.

The processor and model are loaded once at startup from the fixed path
``/model``.  Which classes to use and the torch dtype can be overridden
through the ``CONFIG_JSON`` environment variable (a JSON object).
"""

import base64
import json
import os
from io import BytesIO

import torch
import transformers
import uvicorn
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Import under an alias so the module name is not shadowed by the
# logger instance bound below (the original rebound `logger` over it).
import logger as _logger_module

logger = _logger_module.get_logger(__file__)

app = FastAPI()

# Module-level model state, populated by load_model() at startup.
initialized = False
processor = None
model = None
device = None


class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # Base64-encoded image bytes.
    image: str
    # Optional name, used only for logging.
    image_name: str | None = None


def load_config():
    """Build the runtime config: defaults merged with ``CONFIG_JSON``.

    Returns a dict with keys ``model_class``, ``processer_class`` and
    ``torch_dtype``.  The historical (misspelled) key ``processer_class``
    is kept for backward compatibility; the correctly spelled
    ``processor_class`` is accepted as an alias.

    Raises:
        ValueError: if CONFIG_JSON is set but is not valid JSON, or does
            not decode to a JSON object.
    """
    default_config = {
        "model_class": "AutoModelForImageClassification",
        "processer_class": "AutoImageProcessor",
        "torch_dtype": "auto",
    }
    raw_config = os.getenv("CONFIG_JSON")
    if not raw_config:
        return default_config
    try:
        user_config = json.loads(raw_config)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc
    if not isinstance(user_config, dict):
        raise ValueError("CONFIG_JSON must decode to an object")
    default_config.update(user_config)
    # Accept the correctly spelled "processor_class" as an alias for the
    # legacy "processer_class" key (the legacy key wins when both are set).
    if "processor_class" in user_config and "processer_class" not in user_config:
        default_config["processer_class"] = user_config["processor_class"]
    return default_config


def resolve_transformers_class(class_name):
    """Resolve *class_name* on the transformers package, or raise ValueError."""
    try:
        return getattr(transformers, class_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported transformers class: {class_name}") from exc


def resolve_torch_dtype(dtype_name):
    """Map a dtype name to a torch dtype; the string "auto" passes through."""
    if dtype_name == "auto":
        return "auto"
    try:
        return getattr(torch, dtype_name)
    except AttributeError as exc:
        raise ValueError(f"unsupported torch dtype: {dtype_name}") from exc


@app.get("/")
def read_root():
    """Liveness endpoint."""
    return {"message": "Hello, World!"}


# NOTE(review): @app.on_event is deprecated in recent FastAPI releases;
# consider migrating to a lifespan handler when upgrading the dependency.
@app.on_event("startup")
def load_model():
    """Load the processor and model from /model and move them to the device."""
    logger.info("loading model")
    global initialized, processor, model, device
    config = load_config()
    processor_class = resolve_transformers_class(config["processer_class"])
    model_class = resolve_transformers_class(config["model_class"])
    torch_dtype = resolve_torch_dtype(config["torch_dtype"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(
        f"model config: model_class={config['model_class']}, "
        f"processer_class={config['processer_class']}, torch_dtype={config['torch_dtype']}, "
        f"device={device}"
    )
    processor = processor_class.from_pretrained("/model")
    model = model_class.from_pretrained("/model", torch_dtype=torch_dtype)
    model.to(device)
    model.eval()
    initialized = True
    logger.info("model loaded successfully")


@app.get("/v1/models")
async def get_status():
    """Return whether the model has finished loading (bare boolean body)."""
    logger.info(f"get status, initialized={initialized}")
    return initialized


@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify a base64-encoded image and return the top-5 label indices.

    Responses:
        503 if the model is not yet loaded, 400 for an undecodable image
        payload, 500 for any failure during inference.
    """
    if not initialized or processor is None or model is None or device is None:
        raise HTTPException(status_code=503, detail="model is not initialized")
    try:
        image_bytes = base64.b64decode(request.image)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        logger.exception("failed to decode input image")
        raise HTTPException(status_code=400, detail="invalid image payload") from exc
    try:
        inputs = processor(images=image, return_tensors="pt")
        # Move every processor output tensor onto the model's device.
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        # indices[0]: single-image batch -> top-5 class indices as a list.
        top5 = torch.topk(logits, k=5, dim=-1).indices[0].tolist()
        logger.info(f"classify image_name={request.image_name}, labels={top5}")
        return {"labels": top5}
    except Exception as exc:
        logger.exception("image classification failed")
        raise HTTPException(status_code=500, detail="classification failed") from exc


if __name__ == '__main__':
    uvicorn.run(
        "transformers_server:app",
        host="0.0.0.0",
        port=8000,
        workers=1,
        access_log=False,
    )