# NOTE: stray diff hunk header removed during review; file is plain Python source.
import base64
import gc
import io
import os
import time
from typing import List , Optional , Dict , Any , Tuple
import torch
from PIL import Image
from fastapi import FastAPI , HTTPException
from pydantic import BaseModel
from transformers import (
AutoProcessor ,
AutoTokenizer ,
AutoConfig ,
AutoModelForCausalLM ,
AutoModelForVision2Seq , AutoModel
)
try :
from transformers import ( Qwen2VLForConditionalGeneration , Gemma3ForConditionalGeneration )
except ImportError :
pass
app = FastAPI ( title = " Unified VLM API (Transformers) " )
# ----------------------------
# Device selection & sync
# ----------------------------
def best_device() -> torch.device:
    """Pick the most capable accelerator available, falling back to CPU.

    Preference order: CUDA (covers NVIDIA and AMD ROCm builds under
    torch.cuda), Intel oneAPI/XPU, Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def device_name(device: torch.device) -> str:
    """Return a human-readable label for *device*."""
    if device.type == "cuda":
        # Query the actual adapter name; some builds raise here.
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "CUDA device"
    fixed_labels = {"xpu": "Intel XPU", "mps": "Apple MPS"}
    return fixed_labels.get(device.type, "CPU")
def device_total_mem_gb(device: torch.device) -> Optional[float]:
    """Total memory of *device* in GiB, or None when it cannot be queried.

    Only CUDA exposes a stable total-memory API; other backends vary, so
    they report None rather than a guess.
    """
    if device.type != "cuda":
        return None
    try:
        props = torch.cuda.get_device_properties(0)
    except Exception:
        return None
    return props.total_memory / (1024 ** 3)
def synchronize(device: torch.device) -> None:
    """Block until pending work on *device* finishes (for honest wall times)."""
    if device.type == "cuda":
        action = torch.cuda.synchronize
    elif device.type == "xpu" and hasattr(torch, "xpu"):
        action = torch.xpu.synchronize
    elif device.type == "mps" and hasattr(torch, "mps"):
        action = torch.mps.synchronize
    else:
        return
    try:
        action()
    except Exception:
        # Best-effort only: a timing helper must never crash a request.
        pass
# ----------------------------
# Model registry
# ----------------------------
class LoadedModel:
    """Bundle of everything needed to serve one loaded checkpoint.

    Holds the HF model object, its (optional) processor and tokenizer, and
    the device/dtype it was placed on, keyed into the ``_loaded`` registry
    by ``model_path``.
    """

    def __init__(self, model_type: str, model_path: str, model, processor, tokenizer,
                 device: torch.device, dtype: torch.dtype):
        # model_type comes from AutoConfig.model_type and drives dispatch in /infer.
        self.model_type = model_type
        self.model_path = model_path
        self.model = model
        # processor/tokenizer may each be None depending on the model family.
        self.processor = processor
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
# Registry of currently loaded models, keyed by the model path string.
_loaded: Dict[str, LoadedModel] = {}
# ----------------------------
# IO helpers
# ----------------------------
def load_image(ref: str) -> Image.Image:
    """Resolve *ref* into an RGB PIL image.

    Accepts http(s) URLs, local file paths, ``data:image`` URLs, or raw
    base64 payloads. Raises ValueError when none of those apply.
    """
    def _from_bytes(raw: bytes) -> Image.Image:
        return Image.open(io.BytesIO(raw)).convert("RGB")

    if ref.startswith(("http://", "https://")):
        # Imported lazily so offline deployments never need requests.
        import requests
        resp = requests.get(ref, timeout=30)
        resp.raise_for_status()
        return _from_bytes(resp.content)
    if os.path.exists(ref):
        return Image.open(ref).convert("RGB")
    if ref.startswith("data:image"):
        _header, payload = ref.split(",", 1)
        return _from_bytes(base64.b64decode(payload))
    # Last resort: treat the whole string as raw base64.
    try:
        return _from_bytes(base64.b64decode(ref))
    except Exception:
        raise ValueError(f"Unsupported image reference: {ref[:80]}...")
def pick_dtype(req_dtype: str, device: torch.device) -> torch.dtype:
    """Map a requested dtype string to a torch.dtype suited to *device*.

    Args:
        req_dtype: "auto", "float16", "bfloat16" or "float32". Explicit
            names are honored verbatim; "auto" resolves per device below.
        device: target device used to resolve "auto".
    """
    explicit = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if req_dtype in explicit:
        return explicit[req_dtype]
    # "auto" resolution.
    if device.type == "cuda":
        # bfloat16 works broadly on modern GPUs; fall back to float16 for
        # older CUDA parts (probe may raise on exotic builds).
        try:
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        except Exception:
            return torch.float16
    if device.type == "xpu":
        # BUGFIX: this branch previously probed torch.cuda.is_bf16_supported()
        # -- the wrong backend for XPU -- so the except clause always demoted
        # XPU to float16. Use bfloat16 directly, matching the stated intent.
        return torch.bfloat16
    if device.type == "mps":
        return torch.float16
    return torch.float32
def autocast_ctx(device: torch.device, dtype: torch.dtype):
    """Return an autocast context for *device*; a no-op for unknown types."""
    supported = ("cpu", "cuda", "xpu", "mps")
    if device.type in supported:
        return torch.autocast(device_type=device.type, dtype=dtype)
    # Unknown backend: hand back a do-nothing context manager.
    from contextlib import nullcontext
    return nullcontext()
# ----------------------------
# Requests/Responses
# ----------------------------
class GenParams(BaseModel):
    """Decoding parameters forwarded to the model's generate()/chat() calls."""
    max_new_tokens: int = 128
    temperature: float = 0.0  # only meaningful when do_sample is True
    top_p: float = 1.0
    do_sample: bool = False   # False => greedy decoding
class InferRequest(BaseModel):
    """Request body for POST /infer."""
    model_path: str   # HF hub id or local checkpoint directory
    prompt: str
    images: List[str]  # URLs, file paths, data URLs, or raw base64 strings
    # NOTE(review): mutable default is safe here only because pydantic copies
    # field defaults per instance -- confirm against the pinned pydantic version.
    generation: GenParams = GenParams()
    dtype: str = "auto"  # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1
    measure_token_times: bool = False  # not referenced by the visible handlers
class InferResponse(BaseModel):
    """Response body for POST /infer."""
    output_text: str            # decoded model output (single string)
    timings_ms: Dict[str, float]  # preprocess / generate / postprocess / e2e
    device: Dict[str, Any]      # type, name, total_memory_gb
    model_info: Dict[str, Any]  # name, precision, framework
class LoadModelRequest(BaseModel):
    """Request body for POST /load_model."""
    model_path: str
    dtype: str = "auto"
class UnloadModelRequest(BaseModel):
    """Request body for POST /unload_model."""
    model_path: str
# ----------------------------
# Model loading
# ----------------------------
def resolve_model(model_path: str, dtype_str: str) -> LoadedModel:
    """Load the model at *model_path*, or return it from the cache.

    Dispatches on the checkpoint's ``config.model_type`` to pick the right
    model class (Qwen2-VL, InternLM-XComposer2, Gemma-3); anything else
    falls back to the generic Auto* loaders. The loaded bundle is stored in
    ``_loaded`` so repeat requests are free.

    Raises:
        RuntimeError: when none of the fallback Auto* classes can load it.
    """
    if model_path in _loaded:
        return _loaded[model_path]
    dev = best_device()
    dt = pick_dtype(dtype_str, dev)
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    print(f"model type detected: {model_type}, device: {dev}, dt: {dt}")
    if model_type in ("qwen2_vl", "qwen2-vl"):
        print(f"Loading Qwen2-VL using Qwen2VLForConditionalGeneration")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,
            trust_remote_code=True,
        )
        print("Loaded model class:", type(model))
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm
    # BUGFIX: was `model_type in ("internlmxcomposer2")` -- parentheses around
    # a lone string are not a tuple, so that was a *substring* test. Use a
    # real one-element tuple for membership.
    elif model_type in ("internlmxcomposer2",):
        dt = torch.float16  # override: this family is loaded in fp16
        print(f"dt change to {dt}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=dt, trust_remote_code=True, device_map='auto')
        model = model.eval()
        lm = LoadedModel(model_type, model_path, model, None, tokenizer, dev, dt)
        _loaded[model_path] = lm
        return lm
    elif model_type in ("gemma3", "gemma-3", "gemma_3"):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # we move to the target device below
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev).eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm
    # Generic fallback path.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # Tokenizer is often part of the processor; still try to load explicitly.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    except Exception:
        tokenizer = None
    model = None
    errors = []
    # Try progressively more specific Auto* classes until one loads.
    for candidate in (AutoModel, AutoModelForVision2Seq, AutoModelForCausalLM):
        try:
            model = candidate.from_pretrained(
                model_path,
                torch_dtype=dt if dt != torch.float32 else None,
                device_map=None,  # we move to device manually
                trust_remote_code=True,
            )
            break
        except Exception as e:
            errors.append(str(e))
    if model is None:
        raise RuntimeError(f"Unable to load model {model_path}. Errors: {errors}")
    model.to(dev)
    model.eval()
    lm = LoadedModel(model_type, model_path, model, processor, tokenizer, dev, dt)
    _loaded[model_path] = lm
    return lm
def unload_model(model_path: str) -> None:
    """Drop *model_path* from the registry and free its memory (best effort)."""
    lm = _loaded.pop(model_path, None)
    if lm is None:
        return
    try:
        del lm.model
        del lm.processor
        if lm.tokenizer:
            del lm.tokenizer
        gc.collect()
        # Only CUDA exposes an explicit cache-release hook here.
        if lm.device.type == "cuda":
            torch.cuda.empty_cache()
    except Exception:
        # Cleanup is best-effort; never fail an unload request.
        pass
# ----------------------------
# Core inference
# ----------------------------
def prepare_inputs(lm: LoadedModel, prompt: str, images: List[Image.Image]) -> Dict[str, Any]:
    """Encode *prompt* + *images* with the model's processor, moved to device.

    Uses the processor's chat template when available (covers mllama,
    qwen2-vl, MiniCPM-V2, ...), otherwise falls back to a plain call.

    Returns:
        Dict of processor outputs with every tensor moved to ``lm.device``.
    """
    proc = lm.processor
    if hasattr(proc, "apply_chat_template"):
        # BUGFIX: emit one image placeholder *per* image. The old code always
        # emitted exactly one (its own comment said "repeat if >1" but never
        # did), silently breaking multi-image prompts.
        content: List[Dict[str, Any]] = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        conversation = [{"role": "user", "content": content}]
        text_prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)
        encoded = proc(
            text=[text_prompt],  # a list is required by some processors
            images=images,       # list-of-PIL
            padding=True,
            return_tensors="pt",
        )
    else:
        # Generic fallback for processors without a chat template.
        encoded = proc(text=prompt, images=images, return_tensors="pt")
    # Move tensors to the target device; pass everything else through.
    return {k: v.to(lm.device) if torch.is_tensor(v) else v for k, v in encoded.items()}
def generate_text(lm: LoadedModel, inputs: Dict[str, Any], gen: GenParams) -> str:
    """Run generation under no_grad/autocast and decode the first sequence."""
    with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
        out_ids = lm.model.generate(
            **inputs,
            max_new_tokens=gen.max_new_tokens,
            temperature=gen.temperature,
            top_p=gen.top_p,
            do_sample=gen.do_sample,
        )
    first = out_ids[0]
    # Prefer an explicit tokenizer, then one bundled inside the processor.
    if lm.tokenizer:
        return lm.tokenizer.decode(first, skip_special_tokens=True)
    bundled = getattr(lm.processor, "tokenizer", None)
    if bundled is not None:
        return bundled.decode(first, skip_special_tokens=True)
    # Last resort: some remote-code models ship their own decode().
    try:
        return lm.model.decode(first)
    except Exception:
        return "<decode_failed>"
def time_block(fn, sync, *args, **kwargs) -> Tuple[Any, float]:
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_ms)``.

    ``sync`` is invoked after ``fn`` so asynchronous device work is flushed
    before the clock stops, keeping GPU wall times honest.
    """
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    sync()
    elapsed_ms = (time.perf_counter() - t0) * 1000.0
    return result, elapsed_ms
# ----------------------------
# Routes
# ----------------------------
@app.get("/health")
def health():
    """Liveness probe: torch version plus per-backend availability flags."""
    dev = best_device()
    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    xpu_ok = hasattr(torch, "xpu") and torch.xpu.is_available()
    return {
        "status": "ok",
        "device": str(dev),
        "device_name": device_name(dev),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "mps_available": mps_ok,
        "xpu_available": xpu_ok,
    }
@app.get("/info")
def info():
    """Report the serving device and the torch/transformers versions."""
    import transformers  # equivalent to the original __import__ lookup
    dev = best_device()
    return {
        "device": {
            "type": dev.type,
            "name": device_name(dev),
            "total_memory_gb": device_total_mem_gb(dev),
        },
        "torch": torch.__version__,
        "transformers": transformers.__version__,
    }
@app.post("/load_model")
def load_model(req: LoadModelRequest):
    """Eagerly load a model so later /infer calls skip the load cost."""
    lm = resolve_model(req.model_path, req.dtype)
    print(f"model with path {req.model_path} loaded!")
    return {
        "loaded": lm.model_path,
        "device": str(lm.device),
        "dtype": str(lm.dtype),
    }
@app.post("/unload_model")
def unload(req: UnloadModelRequest):
    """Release a previously loaded model and echo which path was dropped.

    BUGFIX: the response previously read ``req.model``, but
    UnloadModelRequest only defines ``model_path``, so every call raised
    AttributeError after the unload succeeded.
    """
    unload_model(req.model_path)
    return {"unloaded": req.model_path}
def handle_normal_case(lm: LoadedModel, warmup_runs: int, images: List[str], prompt: str, generation: GenParams):
    """Generic chat-template inference path with warmup and per-stage timing.

    Returns ``(text, t_pre, t_gen, t_post)`` with times in milliseconds.
    """
    # Warmup passes over a tiny gray image; abort quietly on any failure.
    warm_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
    for _ in range(max(0, warmup_runs)):
        try:
            generate_text(lm, prepare_inputs(lm, "Hello", [warm_img]), GenParams(max_new_tokens=8))
            synchronize(lm.device)
        except Exception:
            break
    # Resolve image references up front so bad input becomes a 400, not a 500.
    try:
        pil_images = [load_image(s) for s in images]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
    # Timed stages.
    synchronize(lm.device)
    inputs, t_pre = time_block(lambda: prepare_inputs(lm, prompt, pil_images), lambda: synchronize(lm.device))
    text, t_gen = time_block(lambda: generate_text(lm, inputs, generation), lambda: synchronize(lm.device))
    # Placeholder stage for future postprocessing (cleaning/parsing output).
    _, t_post = time_block(lambda: None, lambda: synchronize(lm.device))
    return text, t_pre, t_gen, t_post
def handle_minicpmv(lm: LoadedModel, image: Image.Image, prompt: str, gen: GenParams):
    """MiniCPM-V path: use the model's built-in chat() API.

    Only generation is timed; pre/post stages don't exist with the chat
    API, so they are reported as 0.0 ms.
    """
    def run_chat() -> str:
        # model.chat expects a plain message list plus the tokenizer.
        msgs = [{"role": "user", "content": prompt}]
        return lm.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=lm.tokenizer,
            sampling=gen.do_sample,
            temperature=gen.temperature,
            stream=False,  # set True if streaming is added later
        )

    synchronize(lm.device)
    text, t_gen = time_block(run_chat, lambda: synchronize(lm.device))
    return text, 0.0, t_gen, 0.0
def handle_internlm_xcomposer(lm: LoadedModel,
                              images_pil: List[Image.Image],
                              prompt: str,
                              gen: GenParams):
    """InternLM-XComposer2 path: model-supplied CLIP preprocessing + chat().

    Only generation is timed; pre/post stages are reported as 0.0 ms.
    """
    def run_chat():
        # Preprocess every image with the model-supplied visual transform.
        tensors = [lm.model.vis_processor(img.convert("RGB")) for img in images_pil]
        batch = torch.stack(tensors).to(lm.device, dtype=lm.dtype)
        # Build the query: one <ImageHere> token per picture, then the prompt.
        query = ("<ImageHere>" * len(images_pil)).strip() + " " + prompt
        # Run chat-style generation.
        with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
            response, _history = lm.model.chat(
                lm.tokenizer,
                query=query,
                image=batch,
                history=[],
                do_sample=gen.do_sample,
                temperature=gen.temperature,
                max_new_tokens=gen.max_new_tokens,
            )
        return response

    synchronize(lm.device)
    text, t_gen = time_block(run_chat, lambda: synchronize(lm.device))
    return text, 0.0, t_gen, 0.0
def handle_qwen2vl(lm: LoadedModel, image_strings: List[str], prompt: str, gen: GenParams):
    """Qwen2-VL path: chat-template prompt, generate, decode only new tokens.

    Returns ``(text, t_pre, t_gen, t_post)``; timings are not measured on
    this path and are reported as 0.0.
    """
    images = [load_image(s) for s in image_strings]
    image = images[0]  # NOTE(review): only the first image is used -- confirm intent
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text_prompt = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = lm.processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(lm.device)
    with torch.no_grad():
        output_ids = lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens)
    # Strip the prompt tokens so only the generated continuation is decoded.
    # (Loop variable renamed: it previously shadowed `output_ids`.)
    generated_ids = [
        out_seq[len(in_seq):]
        for in_seq, out_seq in zip(inputs.input_ids, output_ids)
    ]
    decoded = lm.processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # BUGFIX: batch_decode returns a *list*, but InferResponse.output_text is
    # declared as str, so returning the list failed response validation.
    return decoded[0], 0.0, 0.0, 0.0
def handle_gemma3(lm: LoadedModel, image_refs: List[str], prompt: str, gen: GenParams):
    """Gemma-3 path: chat-template prompt over the first image.

    Only generation is timed; pre/post stages are reported as 0.0 ms.
    """
    img = load_image(image_refs[0])
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    rendered = lm.processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = lm.processor(
        img,
        rendered,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(lm.device)
    synchronize(lm.device)
    out_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device),
    )
    # NOTE(review): decodes the full sequence, prompt included -- confirm intent.
    return lm.processor.decode(out_ids[0], skip_special_tokens=True), 0.0, t_gen, 0.0
@app.post("/infer", response_model=InferResponse)
def infer(req: InferRequest):
    """Run one vision-language inference and report per-stage timings.

    Dispatches to a family-specific handler when the model needs a custom
    chat API; everything else goes through the generic path.
    """
    print("infer got")
    # Load / reuse model
    lm = resolve_model(req.model_path, req.dtype)
    print(f"{lm.model_type=}")
    if lm.model_type == 'minicpmv':
        text, t_pre, t_gen, t_post = handle_minicpmv(
            lm, load_image(req.images[0]), req.prompt, req.generation)
    elif lm.model_type in ("qwen2vl", "qwen2-vl", "qwen2_vl"):
        text, t_pre, t_gen, t_post = handle_qwen2vl(lm, req.images, req.prompt, req.generation)
    elif lm.model_type == "internlmxcomposer2":
        # Resolve image references so bad input becomes a 400, not a 500.
        try:
            pil_images = [load_image(s) for s in req.images]
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
        text, t_pre, t_gen, t_post = handle_internlm_xcomposer(lm, pil_images, req.prompt, req.generation)
    elif lm.model_type in ("gemma3", "gemma-3", "gemma_3"):
        # FIX: handle_gemma3 existed (and resolve_model loads gemma3 with its
        # dedicated class) but was never dispatched -- gemma requests silently
        # fell through to the generic path. Wire the handler in, matching the
        # type tuple used by resolve_model.
        text, t_pre, t_gen, t_post = handle_gemma3(lm, req.images, req.prompt, req.generation)
    else:
        text, t_pre, t_gen, t_post = handle_normal_case(
            lm, req.warmup_runs, req.images, req.prompt, req.generation)
    timings = {
        "preprocess": t_pre,
        "generate": t_gen,
        "postprocess": t_post,
        "e2e": t_pre + t_gen + t_post,
    }
    return InferResponse(
        output_text=text,
        timings_ms=timings,
        device={
            "type": lm.device.type,
            "name": device_name(lm.device),
            "total_memory_gb": device_total_mem_gb(lm.device),
        },
        model_info={
            "name": lm.model_path,
            "precision": str(lm.dtype).replace("torch.", ""),
            "framework": "transformers",
        },
    )
# Entry
# Run: uvicorn server:app --host 0.0.0.0 --port 8000