init project
This commit is contained in:
129
transformers_server.py
Normal file
129
transformers_server.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import logger

# NOTE(review): this rebinds the name `logger` from the project module to the
# logger instance it returns — the module itself is unreachable afterwards.
# Works, but consider `log = logger.get_logger(__file__)` to avoid shadowing.
logger = logger.get_logger(__file__)
|
||||
|
||||
app = FastAPI()

# Module-level model state, populated by the startup hook (load_model).
# All stay False/None until loading completes.
initialized = False
tokenizer = None
model = None
device = None
|
||||
|
||||
|
||||
def load_config():
    """Parse the CONFIG_JSON environment variable into a dict.

    Returns:
        The decoded JSON object, or an empty dict when the variable is
        unset or blank.

    Raises:
        ValueError: if the variable holds invalid JSON, or valid JSON
            that is not an object.
    """
    payload = os.environ.get("CONFIG_JSON", "").strip()
    if not payload:
        return {}

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError as exc:
        raise ValueError("CONFIG_JSON is not valid JSON") from exc

    # Reject scalars/arrays: downstream code calls .get() on the result.
    if isinstance(parsed, dict):
        return parsed
    raise ValueError("CONFIG_JSON must be a JSON object")
|
||||
|
||||
|
||||
def resolve_transformers_class(class_name, default_name):
    """Look up a class by name on the transformers module.

    Args:
        class_name: Requested class name; any falsy value selects the default.
        default_name: Fallback class name.

    Returns:
        A ``(resolved_name, resolved_class)`` tuple.

    Raises:
        ValueError: if transformers exposes no attribute with that name.
    """
    chosen = class_name if class_name else default_name
    candidate = getattr(transformers, chosen, None)
    if candidate is None:
        raise ValueError(f"Unsupported transformers class: {chosen}")
    return chosen, candidate
|
||||
|
||||
|
||||
def resolve_torch_dtype(dtype_name, default_name="float16"):
    """Map a dtype name to the corresponding torch dtype object.

    The special name ``"auto"`` is passed through unchanged so callers can
    forward it to ``from_pretrained`` as-is.

    Args:
        dtype_name: Requested dtype name; any falsy value selects the default.
        default_name: Fallback dtype name.

    Returns:
        A ``(resolved_name, resolved_dtype)`` tuple; for ``"auto"`` the
        second element is the string ``"auto"`` itself.

    Raises:
        ValueError: if torch has no dtype attribute with the resolved name.
    """
    chosen = dtype_name if dtype_name else default_name
    if chosen == "auto":
        return chosen, "auto"
    dtype = getattr(torch, chosen, None)
    if dtype is None:
        raise ValueError(f"Unsupported torch dtype: {chosen}")
    return chosen, dtype
|
||||
|
||||
|
||||
class ClassifyRequest(BaseModel):
    """Request body for POST /v1/classify."""

    # Raw input text; tokenized downstream with truncation and padding.
    text: str
|
||||
|
||||
@app.get("/")
def read_root():
    """Liveness endpoint: always responds with a static greeting."""
    greeting = {"message": "Hello, World!"}
    return greeting
|
||||
|
||||
@app.on_event("startup")
def load_model():
    """FastAPI startup hook: load the tokenizer and model from /model onto CUDA.

    Populates the module-level ``tokenizer``/``model``/``device`` globals and
    flips ``initialized`` to True once everything is ready.

    Raises:
        RuntimeError: when no CUDA device is available.
        ValueError: when CONFIG_JSON is malformed or names an unknown
            transformers class or torch dtype.
    """
    global initialized, tokenizer, model, device

    logger.info("loading model")

    # Hard requirement: this server only serves from a GPU.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but is not available")
    device = torch.device("cuda")

    weights_dir = "/model"
    config = load_config()

    # Resolve the classes and dtype named in CONFIG_JSON, with defaults.
    tok_name, tok_cls = resolve_transformers_class(
        config.get("tokenizer_class"), "AutoTokenizer"
    )
    mdl_name, mdl_cls = resolve_transformers_class(
        config.get("model_class"), "AutoModelForSequenceClassification"
    )
    dtype_name, dtype = resolve_torch_dtype(config.get("torch_dtype"))

    logger.info(
        f"resolved config: model_class={mdl_name}, "
        f"tokenizer_class={tok_name}, "
        f"torch_dtype={dtype_name}"
    )

    tokenizer = tok_cls.from_pretrained(weights_dir)
    model = mdl_cls.from_pretrained(weights_dir, torch_dtype=dtype).to(device)
    model.eval()

    initialized = True
    logger.info("model loaded successfully")
|
||||
|
||||
@app.get("/v1/models")
async def get_status():
    """Readiness probe: True once the startup hook finished loading the model."""
    ready = initialized
    logger.info(f"get status, initialized={ready}")
    return ready
|
||||
|
||||
@app.post("/v1/classify")
async def classify(request: ClassifyRequest):
    """Classify ``request.text`` with the loaded model.

    Returns:
        ``{"label": <predicted label>}``; falls back to the stringified
        class id when the model config has no id2label entry for it.
    """
    text = request.text

    # Tokenize on CPU, then move every input tensor to the model's device.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        result = model(**encoded)

    predicted_class_id = result.logits.argmax(dim=-1).item()
    label_map = model.config.id2label
    predicted_label = label_map.get(predicted_class_id, str(predicted_class_id))

    logger.info(f"text: {text}")
    logger.info(f"predicted_class_id: {predicted_class_id}")
    logger.info(f"predicted_label: {predicted_label}")

    return {"label": predicted_label}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Single worker on purpose: the model is held in this process's GPU memory.
    uvicorn.run(
        "transformers_server:app",
        host="0.0.0.0",
        port=8000,
        workers=1,
        access_log=False,
    )
|
||||
|
||||
Reference in New Issue
Block a user