Files
enginex-bi_150-text-classif…/transformers_server.py
2026-04-10 10:05:19 +00:00

130 lines
3.5 KiB
Python

import json
import os
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
import torch
import transformers
import logger
# Module-level logger. Note that this rebinds the name ``logger`` from the
# imported module to the object returned by its ``get_logger`` factory.
logger = logger.get_logger(__file__)

# FastAPI application that the route decorators in this module attach to.
app = FastAPI()

# Process-wide inference state. Everything starts out unset and is assigned
# by the ``load_model`` startup hook.
initialized = False
tokenizer = model = device = None
def load_config():
    """Read optional server settings from the ``CONFIG_JSON`` environment variable.

    Returns:
        The decoded JSON object as a ``dict``, or an empty ``dict`` when the
        variable is unset, empty, or contains only whitespace.

    Raises:
        ValueError: If the value is not valid JSON, or if it decodes to
            something other than a JSON object.
    """
    payload = os.environ.get("CONFIG_JSON", "").strip()
    if payload:
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError as exc:
            raise ValueError("CONFIG_JSON is not valid JSON") from exc
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("CONFIG_JSON must be a JSON object")
    return {}
def resolve_transformers_class(class_name, default_name):
    """Look up a loader class exported by the ``transformers`` package.

    Args:
        class_name: Requested class name; a falsy value (``None`` or ``""``)
            selects ``default_name`` instead.
        default_name: Name used when ``class_name`` is not provided.

    Returns:
        A ``(resolved_name, resolved_class)`` tuple.

    Raises:
        ValueError: If the name is not a string, is not an attribute of
            ``transformers``, or refers to an object that has no callable
            ``from_pretrained``.
    """
    resolved_name = class_name or default_name
    # Callers may pass values decoded from JSON, which can be numbers or lists;
    # getattr() would raise TypeError for those, so reject them up front with
    # the same ValueError used for unknown names.
    if not isinstance(resolved_name, str):
        raise ValueError(f"Unsupported transformers class: {resolved_name}")
    resolved_class = getattr(transformers, resolved_name, None)
    # Not every attribute of the package is loadable (submodules, version
    # strings, ...). Requiring ``from_pretrained`` rejects those here instead
    # of failing later with an AttributeError; a missing attribute (None) is
    # caught by the same check.
    if not callable(getattr(resolved_class, "from_pretrained", None)):
        raise ValueError(f"Unsupported transformers class: {resolved_name}")
    return resolved_name, resolved_class
def resolve_torch_dtype(dtype_name, default_name="float16"):
    """Translate a dtype name into a value usable as a ``torch_dtype`` argument.

    Args:
        dtype_name: Requested dtype name such as ``"float16"`` or
            ``"bfloat16"``, or the literal ``"auto"``. A falsy value
            (``None`` or ``""``) selects ``default_name``.
        default_name: Name used when ``dtype_name`` is not provided.

    Returns:
        A ``(resolved_name, resolved_dtype)`` tuple. For ``"auto"`` the second
        element is the string ``"auto"``; otherwise it is a ``torch.dtype``.

    Raises:
        ValueError: If the name is not a string or does not name a
            ``torch.dtype`` attribute of the ``torch`` module.
    """
    resolved_name = dtype_name or default_name
    # Non-string names (e.g. a JSON number) would make getattr() raise
    # TypeError; report them with the same ValueError as unknown names.
    if not isinstance(resolved_name, str):
        raise ValueError(f"Unsupported torch dtype: {resolved_name}")
    if resolved_name == "auto":
        return resolved_name, "auto"
    resolved_dtype = getattr(torch, resolved_name, None)
    # ``torch`` exposes many non-dtype attributes ("nn", "tensor", "cuda", ...);
    # an existence check alone would accept them, so require a real dtype.
    if not isinstance(resolved_dtype, torch.dtype):
        raise ValueError(f"Unsupported torch dtype: {resolved_name}")
    return resolved_name, resolved_dtype
class ClassifyRequest(BaseModel):
    """Request body accepted by ``POST /v1/classify``."""

    # Raw input text that the classify endpoint passes to the tokenizer.
    text: str
@app.get("/")
def read_root():
    """Serve a fixed greeting payload at the root path."""
    greeting = {"message": "Hello, World!"}
    return greeting
@app.on_event("startup")
def load_model():
    """Load the tokenizer and model from ``/model`` when the application starts.

    The tokenizer class, model class and torch dtype are taken from the
    ``CONFIG_JSON`` settings (via ``load_config``), falling back to
    ``AutoTokenizer``, ``AutoModelForSequenceClassification`` and the default
    of ``resolve_torch_dtype``. On success the module globals ``tokenizer``,
    ``model`` and ``device`` are populated and ``initialized`` is set to
    ``True``.

    Raises:
        RuntimeError: If CUDA is not available.
        ValueError: Propagated from the config/class/dtype helpers.
    """
    global initialized, tokenizer, model, device

    logger.info("loading model")

    # This server only runs on a CUDA device; fail startup otherwise.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but is not available")
    device = torch.device("cuda")

    model_path = "/model"
    settings = load_config()

    tok_name, tok_cls = resolve_transformers_class(
        settings.get("tokenizer_class"), "AutoTokenizer"
    )
    mdl_name, mdl_cls = resolve_transformers_class(
        settings.get("model_class"), "AutoModelForSequenceClassification"
    )
    dtype_name, dtype = resolve_torch_dtype(settings.get("torch_dtype"))

    logger.info(
        f"resolved config: model_class={mdl_name}, "
        f"tokenizer_class={tok_name}, torch_dtype={dtype_name}"
    )

    tokenizer = tok_cls.from_pretrained(model_path)
    model = mdl_cls.from_pretrained(model_path, torch_dtype=dtype)
    model = model.to(device)
    # Switch to evaluation mode before serving requests.
    model.eval()

    initialized = True
    logger.info("model loaded successfully")
@app.get("/v1/models")
async def get_status():
    """Report whether the startup hook has finished loading the model.

    Returns:
        The current value of the module-level ``initialized`` flag.
    """
    ready = initialized
    logger.info(f"get status, initialized={ready}")
    return ready
@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify the request text and return the highest-scoring label.

    Args:
        request: Body carrying the ``text`` to classify.

    Returns:
        ``{"label": <name>}`` where the name comes from the model config's
        ``id2label`` mapping, or the stringified class index when the mapping
        has no entry for it.
    """
    text = request.text

    # Tokenize into PyTorch tensors and move each one to the model's device.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**model_inputs).logits

    predicted_class_id = logits.argmax(dim=-1).item()
    label_map = model.config.id2label
    predicted_label = label_map.get(predicted_class_id, str(predicted_class_id))

    logger.info(f"text: {text}")
    logger.info(f"predicted_class_id: {predicted_class_id}")
    logger.info(f"predicted_label: {predicted_label}")
    return {"label": predicted_label}
if __name__ == "__main__":
    # Run a single uvicorn worker on all interfaces, port 8000, with
    # uvicorn's access log turned off.
    uvicorn.run(
        "transformers_server:app",
        host="0.0.0.0",
        port=8000,
        workers=1,
        access_log=False,
    )