"""Minimal FastAPI server exposing a single ``/infer`` endpoint for a local model.

The model is loaded once at application startup from ``MODEL_PATH`` and kept in
module-level globals. Each request runs the prompt against a fixed image on
disk (``IMAGE_PATH``).
"""

import base64
import gc
import io
import os
import time
from typing import List, Optional, Dict, Any, Tuple

import torch
import uvicorn
from PIL import Image
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from transformers import (AutoTokenizer, AutoConfig, AutoModelForCausalLM,
                          AutoModelForVision2Seq, AutoModel)

import logger

log = logger.get_logger(__file__)

app = FastAPI()

# Directory the model, config and tokenizer are loaded from.
MODEL_PATH = "/model"
# Image fed to the model on every request (resolved against the working dir).
IMAGE_PATH = "1.PNG"

# Populated by load_model() at startup.
status = None
model_type = None
model = None
device = None
tokenizer = None


class GenParams(BaseModel):
    """Generation settings supplied by the client.

    Only ``do_sample`` and ``temperature`` are forwarded to the model by
    ``handle_minicpmv``; ``max_new_tokens`` and ``top_p`` are accepted but not
    read anywhere in this file.
    """

    max_new_tokens: int = 128
    temperature: float = 0.0
    top_p: float = 1.0
    do_sample: bool = False


class InferRequest(BaseModel):
    """Request body of ``POST /infer``.

    ``infer`` reads only ``prompt`` and ``generation``; the remaining fields
    are accepted but not used in this file.
    """

    prompt: str
    generation: GenParams = GenParams()
    dtype: str = "auto"  # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1
    measure_token_times: bool = False


@app.on_event("startup")
def load_model():
    """Load config, tokenizer and model from ``MODEL_PATH`` into the globals.

    The model is loaded in float32, moved to CUDA and switched to eval mode.
    On completion ``status`` is set to ``"success"``.
    """
    global status, device, model_type, model, tokenizer

    log.info("loading model")

    cfg = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model_type = cfg.model_type
    log.info(f"model type: {model_type}")

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH, trust_remote_code=True, use_fast=True
    )
    model = AutoModel.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,
        device_map=None,
        trust_remote_code=True,
    )

    # Record the target device in the module global so it is no longer left
    # as None after startup.
    device = "cuda"
    model.to(device)
    model.eval()

    status = "success"
    log.info("model loaded successfully")


@app.post("/infer")
def infer(req: InferRequest):
    """Run the prompt in ``req`` against ``IMAGE_PATH`` and return the model output.

    Returns:
        A dict with the single key ``"output_text"``.

    Raises:
        HTTPException: 501 when the loaded model type has no handler here;
            500 when the image file cannot be opened or decoded.
    """
    # Guard first: without this, an unsupported model type left `text`
    # unassigned and the request failed with an UnboundLocalError.
    if model_type != "minicpmv":
        raise HTTPException(
            status_code=501,
            detail=f"unsupported model type: {model_type!r}",
        )

    try:
        # The context manager closes the underlying file; convert() returns a
        # fully loaded copy that stays valid after the file is closed.
        with Image.open(IMAGE_PATH) as raw_image:
            image = raw_image.convert('RGB')
    except OSError as err:
        log.info(f"failed to load image {IMAGE_PATH!r}: {err}")
        raise HTTPException(
            status_code=500,
            detail=f"cannot load image {IMAGE_PATH!r}",
        ) from err

    text = handle_minicpmv(image, req.prompt, req.generation)
    log.info(f"text={text}")

    return {"output_text": text}


def handle_minicpmv(image: Image.Image, prompt: str, gen: GenParams):
    """Send one user message plus ``image`` through the model's ``chat`` method.

    Args:
        image: Image passed to ``model.chat``.
        prompt: Text used as the content of the single user message.
        gen: Generation settings; only ``do_sample`` and ``temperature`` are
            forwarded.

    Returns:
        Whatever ``model.chat`` returns with ``stream=False`` (presumably the
        generated text; the method comes from the model's remote code).
    """
    # Prepare msgs in the format expected by model.chat
    msgs = [{"role": "user", "content": prompt}]

    # Call the model's built-in chat method
    response = model.chat(
        image=image,
        msgs=msgs,
        tokenizer=tokenizer,
        sampling=gen.do_sample,
        temperature=gen.temperature,
        stream=False,
    )
    return response


if __name__ == '__main__':
    uvicorn.run("server:app", host="0.0.0.0", port=8000, workers=1, access_log=False)