import json
import logging
import os

import torch
import transformers
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI()

initialized = False
tokenizer = None
model = None
device = None


def load_config():
    """Parse the optional CONFIG_JSON environment variable into a dict."""
    raw_config = os.environ.get("CONFIG_JSON", "").strip()
    if not raw_config:
        return {}
    try:
        config = json.loads(raw_config)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc
    if not isinstance(config, dict):
        raise ValueError("CONFIG_JSON must be a JSON object")
    return config


def resolve_transformers_class(class_name, default_name):
    """Look up a class by name on the transformers module, falling back to a default."""
    resolved_name = class_name or default_name
    resolved_class = getattr(transformers, resolved_name, None)
    if resolved_class is None:
        raise ValueError(f"Unsupported transformers class: {resolved_name}")
    return resolved_name, resolved_class


def resolve_torch_dtype(dtype_name, default_name="float16"):
    """Map a dtype name to a torch.dtype; "auto" is passed through to from_pretrained."""
    resolved_name = dtype_name or default_name
    if resolved_name == "auto":
        return resolved_name, "auto"
    resolved_dtype = getattr(torch, resolved_name, None)
    if resolved_dtype is None:
        raise ValueError(f"Unsupported torch dtype: {resolved_name}")
    return resolved_name, resolved_dtype


class ClassifyRequest(BaseModel):
    text: str


@app.get("/")
def read_root():
    return {"message": "Hello, World!"}


@app.on_event("startup")
def load_model():
    logger.info("loading model")
    global initialized, tokenizer, model, device
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but is not available")
    device = torch.device("cuda")
    model_path = "/model"
    config = load_config()
    tokenizer_class_name, tokenizer_class = resolve_transformers_class(
        config.get("tokenizer_class"),
        "AutoTokenizer",
    )
    model_class_name, model_class = resolve_transformers_class(
        config.get("model_class"),
        "AutoModelForSequenceClassification",
    )
    torch_dtype_name, torch_dtype = resolve_torch_dtype(config.get("torch_dtype"))
    logger.info(
        "resolved config: "
        f"model_class={model_class_name}, "
        f"tokenizer_class={tokenizer_class_name}, "
        f"torch_dtype={torch_dtype_name}"
    )
    tokenizer = tokenizer_class.from_pretrained(model_path)
    model = model_class.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
    )
    model = model.to(device)
    model.eval()
    initialized = True
    logger.info("model loaded successfully")


@app.get("/v1/models")
async def get_status():
    logger.info(f"get status, initialized={initialized}")
    return initialized


@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    text = request.text
    # Tokenize on CPU, then move the input tensors to the model's device.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class_id = outputs.logits.argmax(dim=-1).item()
    # Map the class index to a human-readable label when the model config provides one.
    id2label = model.config.id2label
    predicted_label = id2label.get(predicted_class_id, str(predicted_class_id))
    logger.info(f"text: {text}")
    logger.info(f"predicted_class_id: {predicted_class_id}")
    logger.info(f"predicted_label: {predicted_label}")
    return {"label": predicted_label}


if __name__ == "__main__":
    uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)
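

# Example usage (a sketch, not part of the server itself): it assumes this file is
# saved as transformers_server.py, a sequence-classification model and its tokenizer
# live under /model, and a CUDA GPU is available. CONFIG_JSON is optional; the values
# shown are illustrative and each key falls back to the defaults resolved above.
#
#   CONFIG_JSON='{"model_class": "AutoModelForSequenceClassification", "torch_dtype": "float16"}' \
#       python transformers_server.py
#
#   curl http://localhost:8000/v1/models
#   curl -X POST http://localhost:8000/v1/classify \
#       -H "Content-Type: application/json" \
#       -d '{"text": "This movie was great!"}'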