85 lines
2.2 KiB
Python
85 lines
2.2 KiB
Python
|
|
import base64
|
||
|
|
import gc
|
||
|
|
import io
|
||
|
|
import os
|
||
|
|
import time
|
||
|
|
import uvicorn
|
||
|
|
from typing import List, Optional, Dict, Any, Tuple
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from PIL import Image
|
||
|
|
from fastapi import FastAPI, HTTPException, Query
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from transformers import (AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoModel)
|
||
|
|
|
||
|
|
import logger
|
||
|
|
# Logger obtained from the project-local `logger` module, keyed by this
# file's path.
log = logger.get_logger(__file__)

# Web application object; the startup hook and the /infer route below are
# registered on it through decorators.
app = FastAPI()

# Process-wide inference state. Each name starts as None and is rebound by
# load_model() through `global` when the application starts.
model_type: Optional[str] = None  # `model_type` attribute of the model config
model: Optional[Any] = None  # model object returned by from_pretrained
device: Optional[str] = None  # device the model is placed on, when recorded
tokenizer: Optional[Any] = None  # tokenizer loaded from the same path

# load_model() declares `status` global and assigns it, but it previously
# had no module-level value, so reading it before startup finished raised
# NameError. Initialise it like the other globals.
status: Optional[str] = None  # "success" once load_model() has completed
|
||
|
|
|
||
|
|
class GenParams(BaseModel):
    """Generation settings a client may attach to an inference request.

    Every field has a default, so an empty object is valid.
    """

    # Requested cap on generated tokens.
    # NOTE(review): not forwarded by handle_minicpmv in this file.
    max_new_tokens: int = 128

    # Passed to model.chat as `temperature`.
    temperature: float = 0.0

    # NOTE(review): not forwarded by handle_minicpmv in this file.
    top_p: float = 1.0

    # Passed to model.chat as `sampling`.
    do_sample: bool = False
|
||
|
|
|
||
|
|
class InferRequest(BaseModel):
    """Request body of the ``/infer`` endpoint.

    Only ``prompt`` and ``generation`` are read by the handler in this file;
    the remaining fields are accepted but not used here.
    """

    # Text sent to the model as the single user message.
    prompt: str

    # Generation settings; a default-constructed GenParams when omitted.
    generation: GenParams = GenParams()

    # "auto"|"float16"|"bfloat16"|"float32"
    dtype: str = "auto"

    warmup_runs: int = 1

    measure_token_times: bool = False
|
||
|
|
|
||
|
|
@app.on_event("startup")
def load_model():
    """Load the config, tokenizer and model from ``/model`` at app startup.

    Rebinds the module-level names ``model_type``, ``tokenizer``, ``model``,
    ``device`` and ``status``. The model is loaded in float32, moved to CUDA
    when it is available (CPU otherwise) and switched to eval mode.

    Any exception raised by the loaders propagates to the caller; in that
    case ``status`` is not set to ``"success"``.
    """
    global status, device, model_type, model, tokenizer

    log.info("loading model")
    model_path = "/model"

    # NOTE: trust_remote_code=True lets the loaders run Python code shipped
    # inside the model directory, so `model_path` must be a trusted source.
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    log.info(f"model type: {model_type}")

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=True,
    )

    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map=None,
        trust_remote_code=True,
    )

    # Previously the target was hard-coded to "cuda" and the `device` global
    # was never assigned: startup crashed on hosts without a GPU and `device`
    # stayed None. Pick the device once and record it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info(f"using device: {device}")
    model.to(device)
    model.eval()

    status = "success"
    log.info("model loaded successfully")
|
||
|
|
|
||
|
|
@app.post("/infer")
def infer(req: InferRequest):
    """Answer the request prompt with the loaded model.

    Args:
        req: Request body; ``req.prompt`` and ``req.generation`` are used.

    Returns:
        A dict with the single key ``"output_text"`` holding the model reply.

    Raises:
        HTTPException: 501 when the loaded ``model_type`` has no handler.
    """
    # Fail early with a clear HTTP error. Before, an unsupported model type
    # fell through to an unbound `text` variable and surfaced as an opaque
    # UnboundLocalError / 500.
    if model_type != "minicpmv":
        raise HTTPException(
            status_code=501,
            detail=f"unsupported model type: {model_type}",
        )

    # NOTE(review): the image is always the fixed file '1.PNG', resolved
    # against the working directory; the request carries no image.
    # The context manager releases the file handle once the RGB copy exists.
    with Image.open('1.PNG') as source:
        image = source.convert('RGB')

    text = handle_minicpmv(image, req.prompt, req.generation)
    log.info(f"text={text}")

    return {"output_text": text}
|
||
|
|
|
||
|
|
def handle_minicpmv(image: Image.Image, prompt: str, gen: GenParams):
    """Send one user turn plus an image to the module-level model's ``chat``.

    Args:
        image: Image passed to ``model.chat`` as ``image``.
        prompt: Text used as the content of the single user message.
        gen: Generation settings; only ``do_sample`` (as ``sampling``) and
            ``temperature`` are forwarded. ``max_new_tokens`` and ``top_p``
            are not passed on.

    Returns:
        Whatever ``model.chat`` returns for a non-streaming call.
    """
    # Keyword arguments for the model's built-in chat method, with a
    # one-message conversation built from the prompt.
    chat_kwargs = {
        "image": image,
        "msgs": [{"role": "user", "content": prompt}],
        "tokenizer": tokenizer,
        "sampling": gen.do_sample,
        "temperature": gen.temperature,
        "stream": False,
    }
    return model.chat(**chat_kwargs)
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Serve the app from the `server` module with a single worker, listening
    # on all interfaces at port 8000, with access logging turned off.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "workers": 1,
        "access_log": False,
    }
    uvicorn.run("server:app", **server_options)
|
||
|
|
|