130 lines
3.5 KiB
Python
130 lines
3.5 KiB
Python
|
|
import json
|
||
|
|
import os
|
||
|
|
|
||
|
|
from fastapi import FastAPI
|
||
|
|
from pydantic import BaseModel
|
||
|
|
import uvicorn
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import transformers
|
||
|
|
|
||
|
|
import logger
|
||
|
|
# NOTE(review): this rebinds the name `logger` from the imported module to the
# logger instance it returns — the module itself is unreachable afterwards.
logger = logger.get_logger(__file__)

app = FastAPI()

# Module-level model state, populated once by the startup hook (load_model)
# and read by the request handlers below.
initialized = False  # flipped to True only after tokenizer/model are loaded
tokenizer = None     # transformers tokenizer, loaded from /model
model = None         # transformers model, moved to `device` and set to eval()
device = None        # torch.device("cuda"); startup fails hard without CUDA
|
||
|
|
|
||
|
|
|
||
|
|
def load_config():
    """Parse the optional ``CONFIG_JSON`` environment variable.

    Returns:
        dict: the decoded JSON object, or ``{}`` when the variable is
        unset or blank.

    Raises:
        ValueError: if the value is not valid JSON, or decodes to
        something other than a JSON object.
    """
    raw = os.environ.get("CONFIG_JSON", "").strip()
    if not raw:
        # Unset / whitespace-only: treat as "no overrides".
        return {}

    try:
        config = json.loads(raw)
    except json.JSONDecodeError as exc:
        # Re-raise as ValueError so callers see one exception type
        # for every mis-configuration, with the cause chained.
        raise ValueError("CONFIG_JSON is not valid JSON") from exc

    if isinstance(config, dict):
        return config
    raise ValueError("CONFIG_JSON must be a JSON object")
|
||
|
|
|
||
|
|
|
||
|
|
def resolve_transformers_class(class_name, default_name):
    """Resolve a class by name from the ``transformers`` module.

    Falls back to ``default_name`` when ``class_name`` is falsy.

    Returns:
        tuple: ``(resolved_name, resolved_class)``.

    Raises:
        ValueError: if ``transformers`` has no attribute of that name.
    """
    name = class_name if class_name else default_name
    cls = getattr(transformers, name, None)
    if cls is None:
        raise ValueError(f"Unsupported transformers class: {name}")
    return name, cls
|
||
|
|
|
||
|
|
|
||
|
|
def resolve_torch_dtype(dtype_name, default_name="float16"):
    """Map a dtype name to the corresponding ``torch`` dtype object.

    Falls back to ``default_name`` when ``dtype_name`` is falsy. The
    literal string ``"auto"`` is passed through unchanged rather than
    looked up on ``torch``.

    Returns:
        tuple: ``(resolved_name, resolved_dtype)``.

    Raises:
        ValueError: if ``torch`` has no attribute of that name.
    """
    name = dtype_name if dtype_name else default_name
    if name == "auto":
        return name, "auto"

    dtype = getattr(torch, name, None)
    if dtype is None:
        raise ValueError(f"Unsupported torch dtype: {name}")
    return name, dtype
|
||
|
|
|
||
|
|
|
||
|
|
class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # Raw input text to classify; tokenized server-side by the handler.
    text: str
|
||
|
|
|
||
|
|
@app.get("/")
def read_root():
    """Root endpoint; responds with a static greeting."""
    greeting = {"message": "Hello, World!"}
    return greeting
|
||
|
|
|
||
|
|
@app.on_event("startup")
def load_model():
    """Startup hook: load the tokenizer and model from /model onto the GPU.

    Populates the module-level globals ``tokenizer``, ``model``, ``device``
    and flips ``initialized`` to True on success.

    NOTE(review): ``@app.on_event`` is deprecated in newer FastAPI releases
    in favor of lifespan handlers — confirm the pinned FastAPI version
    before migrating.
    """
    global initialized, tokenizer, model, device

    logger.info("loading model")

    # Hard requirement: this service refuses to start without a GPU.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but is not available")
    device = torch.device("cuda")

    model_path = "/model"
    config = load_config()

    # Resolve the concrete classes/dtype from config, with Auto* defaults.
    tokenizer_class_name, tokenizer_class = resolve_transformers_class(
        config.get("tokenizer_class"), "AutoTokenizer"
    )
    model_class_name, model_class = resolve_transformers_class(
        config.get("model_class"), "AutoModelForSequenceClassification"
    )
    torch_dtype_name, torch_dtype = resolve_torch_dtype(config.get("torch_dtype"))

    logger.info(
        f"resolved config: model_class={model_class_name}, "
        f"tokenizer_class={tokenizer_class_name}, "
        f"torch_dtype={torch_dtype_name}"
    )

    tokenizer = tokenizer_class.from_pretrained(model_path)
    model = model_class.from_pretrained(model_path, torch_dtype=torch_dtype)
    model = model.to(device)
    model.eval()  # inference only — disable dropout etc.

    initialized = True
    logger.info("model loaded successfully")
|
||
|
|
|
||
|
|
@app.get("/v1/models")
async def get_status():
    """Readiness probe: returns the module-level ``initialized`` flag."""
    ready = initialized
    logger.info(f"get status, initialized={ready}")
    return ready
|
||
|
|
|
||
|
|
@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify ``request.text`` with the loaded model.

    Returns:
        dict: ``{"label": <predicted label>}``.

    NOTE(review): assumes the startup hook has completed (``tokenizer``,
    ``model`` and ``device`` are set) — a request arriving earlier will
    fail on the None globals.
    """
    text = request.text

    # Tokenize and move every input tensor onto the model's device.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    # Highest-logit class, mapped through the model's id2label table
    # (falling back to the numeric id as a string).
    class_id = outputs.logits.argmax(dim=-1).item()
    label = model.config.id2label.get(class_id, str(class_id))

    logger.info(f"text: {text}")
    logger.info(f"predicted_class_id: {class_id}")
    logger.info(f"predicted_label: {label}")

    return {"label": label}
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Single worker: the model is held in this process's GPU memory.
    # NOTE(review): the import string assumes this file is named
    # transformers_server.py — confirm, since the filename is not visible here.
    uvicorn.run("transformers_server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)
|
||
|
|
|