Files
enginex-c_series-vl/server.py
2025-09-19 14:46:59 +08:00

85 lines
2.2 KiB
Python

import base64
import gc
import io
import os
import time
import uvicorn
from typing import List, Optional, Dict, Any, Tuple
import torch
from PIL import Image
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from transformers import (AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoModel)
import logger
# Logger for this module, obtained from the project's logger helper.
log = logger.get_logger(__file__)

# FastAPI application; the startup hook and the /infer route attach to it.
app = FastAPI()

# Serving state shared between handlers. All four start out as None and are
# declared `global` in load_model(), which is registered as a startup hook.
tokenizer = None
model = None
model_type = None
device = None
# Generation settings carried in the "generation" field of InferRequest.
# Every field has a default, so callers may omit any of them.
class GenParams(BaseModel):
    # Presumably the cap on generated tokens.
    # NOTE(review): handle_minicpmv does not forward this to model.chat —
    # confirm whether that is intended.
    max_new_tokens: int = 128
    # Forwarded to model.chat as `temperature` by handle_minicpmv.
    temperature: float = 0.0
    # NOTE(review): not forwarded to model.chat by handle_minicpmv — confirm.
    top_p: float = 1.0
    # Forwarded to model.chat as `sampling` by handle_minicpmv.
    do_sample: bool = False
# Request body accepted by the POST /infer handler.
class InferRequest(BaseModel):
    # Text used as the content of the single user message sent to the model.
    prompt: str
    # Generation settings; falls back to a default-constructed GenParams.
    generation: GenParams = GenParams()
    # NOTE(review): dtype, warmup_runs and measure_token_times are accepted
    # but not read by the infer handler in this file — confirm whether they
    # are consumed elsewhere or are placeholders.
    dtype: str = "auto" # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1
    measure_token_times: bool = False
@app.on_event("startup")
def load_model():
    """Load the config, tokenizer and model from ``/model`` onto CUDA.

    Registered as a FastAPI startup hook. Rebinds the module-level
    ``model_type``, ``tokenizer``, ``model`` and ``device`` globals, and sets
    the module-level ``status`` to ``"success"`` once loading has finished.
    Exceptions raised by the loaders or by the device transfer are not
    caught here.
    """
    global status, device, model_type, model, tokenizer
    log.info("loading model")
    model_path = "/model"

    # SECURITY: trust_remote_code=True lets transformers execute Python code
    # shipped with the checkpoint; only serve checkpoints from trusted sources.
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    log.info(f"model type: {model_type}")

    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True, use_fast=True
    )
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map=None,
        trust_remote_code=True,
    )

    # Record the target device in the module-level global (it was declared
    # but never assigned before) and use it for the transfer instead of a
    # second hard-coded literal.
    device = "cuda"
    model.to(device)
    model.eval()

    status = "success"
    log.info("model loaded successfully")
@app.post("/infer")
def infer(req: InferRequest):
    """Answer ``req.prompt`` about the fixed image file ``1.PNG``.

    Args:
        req: Parsed request body. Only ``prompt`` and ``generation`` are
            used by this handler.

    Returns:
        A dict with the single key ``"output_text"`` holding the value
        returned by ``handle_minicpmv``.

    Raises:
        HTTPException: Status 500 when the loaded ``model_type`` is not
            ``"minicpmv"``, the only type this handler dispatches on.
    """
    # Fail explicitly for unsupported models. Without this guard `text` was
    # never bound and the request died with an opaque UnboundLocalError.
    if model_type != "minicpmv":
        raise HTTPException(
            status_code=500,
            detail=f"unsupported model type: {model_type}",
        )

    # The image path is relative, so it resolves against the process working
    # directory. The context manager releases the file handle as soon as the
    # RGB image has been produced.
    with Image.open("1.PNG") as source:
        image = source.convert("RGB")

    text = handle_minicpmv(image, req.prompt, req.generation)
    log.info(f"text={text}")
    return {"output_text": text}
def handle_minicpmv(image: Image.Image, prompt: str, gen: GenParams):
    """Send one user turn about ``image`` to the module-level model's ``chat``.

    Args:
        image: Image passed to ``model.chat`` as ``image``.
        prompt: Text used as the content of the single user message.
        gen: Generation settings; ``do_sample`` is passed as ``sampling`` and
            ``temperature`` as ``temperature``.

    Returns:
        Whatever ``model.chat`` returns for a non-streaming call.
    """
    # NOTE(review): gen.max_new_tokens and gen.top_p are not forwarded to
    # model.chat — confirm whether that is intended.
    chat_kwargs = dict(
        image=image,
        # model.chat takes the conversation as a list of role/content dicts.
        msgs=[{"role": "user", "content": prompt}],
        tokenizer=tokenizer,
        sampling=gen.do_sample,
        temperature=gen.temperature,
        stream=False,
    )
    return model.chat(**chat_kwargs)
if __name__ == "__main__":
    # Serve the app defined in this module (imported by uvicorn as
    # "server:app") on all interfaces, port 8000, with one worker and
    # access logging turned off.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=8000,
        workers=1,
        access_log=False,
    )