import base64
import gc
import io
import os
import time
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)

# Model-specific classes only exist in sufficiently recent transformers
# releases; swallow the ImportError so the generic Auto* path still works.
try:
    from transformers import (
        Gemma3ForConditionalGeneration,
        Qwen2VLForConditionalGeneration,
    )
except ImportError:
    pass

app = FastAPI(title="Unified VLM API (Transformers)")


# ----------------------------
# Device selection & sync
# ----------------------------
def best_device() -> torch.device:
    """Pick the best available accelerator, in priority order."""
    # CUDA covers NVIDIA and AMD ROCm builds under torch.cuda
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Intel oneAPI/XPU (ipex)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    # Apple MPS
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def device_name(device: torch.device) -> str:
    """Human-readable name for *device*, for API responses."""
    if device.type == "cuda":
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "CUDA device"
    if device.type == "xpu":
        return "Intel XPU"
    if device.type == "mps":
        return "Apple MPS"
    return "CPU"


def device_total_mem_gb(device: torch.device) -> Optional[float]:
    """Total device memory in GiB, or None where reporting is unavailable."""
    try:
        if device.type == "cuda":
            return torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        # For others, memory reporting varies
        return None
    except Exception:
        return None


def synchronize(device: torch.device) -> None:
    """Block until queued device work completes, so wall-clock timings are accurate."""
    try:
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "xpu" and hasattr(torch, "xpu"):
            torch.xpu.synchronize()
        elif device.type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()
    except Exception:
        # Best-effort: timing accuracy degrades, but inference still works.
        pass


# ----------------------------
# Model registry
# ----------------------------
class LoadedModel:
    """Bundle of a loaded model plus the artifacts needed to run it.

    Either ``processor`` or ``tokenizer`` may be None depending on the
    model family (chat-API models use only a tokenizer, most VLMs only
    a processor).
    """

    def __init__(self, model_type: str, model_path: str, model, processor, tokenizer,
                 device: torch.device, dtype: torch.dtype):
        self.model_type = model_type
        self.model_path = model_path
        self.model = model
        self.processor = processor
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype


# Cache of models already loaded in this process, keyed by model path.
_loaded: Dict[str, LoadedModel] = {}


# ----------------------------
# IO helpers
# ----------------------------
def load_image(ref: str) -> Image.Image:
    """Resolve an image reference to an RGB PIL image.

    Accepts http(s) URLs, local file paths, ``data:image`` URLs, or raw
    base64. Raises ValueError when nothing matches.
    """
    if ref.startswith("http://") or ref.startswith("https://"):
        # Import lazily so offline deployments without requests still start.
        import requests
        r = requests.get(ref, timeout=30)
        r.raise_for_status()
        return Image.open(io.BytesIO(r.content)).convert("RGB")
    if os.path.exists(ref):
        return Image.open(ref).convert("RGB")
    # Base64 data URL
    if ref.startswith("data:image"):
        header, b64 = ref.split(",", 1)
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    # Raw base64
    try:
        return Image.open(io.BytesIO(base64.b64decode(ref))).convert("RGB")
    except Exception:
        raise ValueError(f"Unsupported image reference: {ref[:80]}...")


def pick_dtype(req_dtype: str, device: torch.device) -> torch.dtype:
    """Map a requested dtype string ("auto"|"float16"|"bfloat16"|"float32")
    to a torch dtype appropriate for *device*."""
    if req_dtype == "float16":
        return torch.float16
    if req_dtype == "bfloat16":
        return torch.bfloat16
    if req_dtype == "float32":
        return torch.float32
    # "auto"
    if device.type in ("cuda", "xpu"):
        # bfloat16 works broadly on modern GPUs; fall back to float16 for
        # older CUDA. NOTE(review): on XPU the cuda bf16 probe raises and we
        # land on float16 — confirm that is the intended XPU default.
        try:
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        except Exception:
            return torch.float16
    if device.type == "mps":
        return torch.float16
    return torch.float32


def autocast_ctx(device: torch.device, dtype: torch.dtype):
    """Return an autocast context manager for *device*, or a no-op one."""
    if device.type in ("cpu", "cuda", "xpu", "mps"):
        return torch.autocast(device_type=device.type, dtype=dtype)
    # Unknown device type: no-op context
    return nullcontext()


# ----------------------------
# Requests/Responses
# ----------------------------
class GenParams(BaseModel):
    max_new_tokens: int = 128
    temperature: float = 0.0
    top_p: float = 1.0
    do_sample: bool = False


class InferRequest(BaseModel):
    model_path: str
    prompt: str
    images: List[str]
    generation: GenParams = GenParams()
    dtype: str = "auto"  # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1
    measure_token_times: bool = False


class InferResponse(BaseModel):
    output_text: str
    timings_ms: Dict[str, float]
    device: Dict[str, Any]
    model_info: Dict[str, Any]


class LoadModelRequest(BaseModel):
    model_path: str
    dtype: str = "auto"


class UnloadModelRequest(BaseModel):
    model_path: str


# ----------------------------
# Model loading
# ----------------------------
def resolve_model(model_path: str, dtype_str: str) -> LoadedModel:
    """Load (or fetch from cache) the model at *model_path*.

    Dispatches on the config's ``model_type`` to family-specific loaders,
    falling back to the generic Auto* classes.
    """
    if model_path in _loaded:
        return _loaded[model_path]

    dev = best_device()
    dt = pick_dtype(dtype_str, dev)
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    print(f"model type detected: {model_type}, device: {dev}, dt: {dt}")

    if model_type in ("qwen2_vl", "qwen2-vl"):
        print(f"Loading Qwen2-VL using Qwen2VLForConditionalGeneration")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,
            trust_remote_code=True,
        )
        print("Loaded model class:", type(model))
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm

    # FIX: original used `model_type in ("internlmxcomposer2")`, which is a
    # substring test against a plain string, not tuple membership.
    elif model_type == "internlmxcomposer2":
        dt = torch.float16  # this family requires fp16
        print(f"dt change to {dt}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=dt, trust_remote_code=True, device_map='auto'
        )
        model = model.eval()
        lm = LoadedModel(model_type, model_path, model, None, tokenizer, dev, dt)
        _loaded[model_path] = lm
        return lm

    elif model_type in ("gemma3", "gemma-3", "gemma_3"):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # we move to device below
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev).eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm

    # Generic fallback path
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # Tokenizer is often part of Processor; still try to load explicitly for safety
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    except Exception:
        tokenizer = None

    model = None
    errors = []
    for candidate in (AutoModel, AutoModelForVision2Seq, AutoModelForCausalLM):
        try:
            model = candidate.from_pretrained(
                model_path,
                torch_dtype=dt if dt != torch.float32 else None,
                device_map=None,  # we move to device manually
                trust_remote_code=True
            )
            break
        except Exception as e:
            errors.append(str(e))
    if model is None:
        raise RuntimeError(f"Unable to load model {model_path}. Errors: {errors}")

    model.to(dev)
    model.eval()
    lm = LoadedModel(model_type, model_path, model, processor, tokenizer, dev, dt)
    _loaded[model_path] = lm
    return lm


def unload_model(model_path: str) -> None:
    """Drop a cached model and release its memory (best-effort)."""
    if model_path in _loaded:
        lm = _loaded.pop(model_path)
        try:
            del lm.model
            del lm.processor
            if lm.tokenizer:
                del lm.tokenizer
            gc.collect()
            if lm.device.type == "cuda":
                torch.cuda.empty_cache()
        except Exception:
            pass


# ----------------------------
# Core inference
# ----------------------------
def prepare_inputs(lm: LoadedModel, prompt: str, images: List[Image.Image]) -> Dict[str, Any]:
    """Encode *prompt* + *images* with the model's processor and move the
    resulting tensors to the model's device."""
    proc = lm.processor
    # If the processor exposes a chat template, use it (covers mllama,
    # qwen2vl, MiniCPM-V2, ...).
    if hasattr(proc, "apply_chat_template"):
        # FIX: emit one image placeholder per input image (the original
        # hard-coded a single placeholder regardless of len(images)).
        content: List[Dict[str, Any]] = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        conversation = [{"role": "user", "content": content}]
        text_prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)
        encoded = proc(
            text=[text_prompt],  # list is required by some processors
            images=images,       # list-of-PIL
            padding=True,
            return_tensors="pt",
        )
    else:
        # generic fallback
        encoded = proc(text=prompt, images=images, return_tensors="pt")
    # Move tensors to the target device; leave non-tensor entries untouched.
    return {k: v.to(lm.device) if torch.is_tensor(v) else v for k, v in encoded.items()}


def generate_text(lm: LoadedModel, inputs: Dict[str, Any], gen: GenParams) -> str:
    """Run generate() on encoded *inputs* and decode the first sequence."""
    gen_kwargs = dict(
        max_new_tokens=gen.max_new_tokens,
        temperature=gen.temperature,
        top_p=gen.top_p,
        do_sample=gen.do_sample
    )
    with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
        out_ids = lm.model.generate(**inputs, **gen_kwargs)
    # Decode with whichever tokenizer is available.
    if lm.tokenizer:
        return lm.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Some processors expose a tokenizer inside
    if hasattr(lm.processor, "tokenizer") and lm.processor.tokenizer is not None:
        return lm.processor.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Last resort
    try:
        return lm.model.decode(out_ids[0])
    except Exception:
        return ""


def time_block(fn, sync, *args, **kwargs) -> Tuple[Any, float]:
    """Run fn(*args, **kwargs), call sync() to flush device work, and return
    (result, elapsed milliseconds)."""
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    sync()
    dur_ms = (time.perf_counter() - start) * 1000.0
    return out, dur_ms


# ----------------------------
# Routes
# ----------------------------
@app.get("/health")
def health():
    dev = best_device()
    return {
        "status": "ok",
        "device": str(dev),
        "device_name": device_name(dev),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        "xpu_available": hasattr(torch, "xpu") and torch.xpu.is_available(),
    }


@app.get("/info")
def info():
    dev = best_device()
    return {
        "device": {
            "type": dev.type,
            "name": device_name(dev),
            "total_memory_gb": device_total_mem_gb(dev)
        },
        "torch": torch.__version__,
        "transformers": __import__("transformers").__version__
    }


@app.post("/load_model")
def load_model(req: LoadModelRequest):
    lm = resolve_model(req.model_path, req.dtype)
    print(f"model with path {req.model_path} loaded!")
    return {
        "loaded": lm.model_path,
        "device": str(lm.device),
        "dtype": str(lm.dtype)
    }


@app.post("/unload_model")
def unload(req: UnloadModelRequest):
    unload_model(req.model_path)
    # FIX: original returned req.model, which does not exist on
    # UnloadModelRequest and raised AttributeError on every call.
    return {"unloaded": req.model_path}


def handle_normal_case(lm: LoadedModel, warmup_runs: int, images: List[str],
                       prompt: str, generation: GenParams):
    """Generic processor-based inference path with warmup and per-stage timing.

    Returns (text, t_preprocess_ms, t_generate_ms, t_postprocess_ms).
    """
    # Warmup: tiny dummy image + short generation to pay one-time costs.
    for _ in range(max(0, warmup_runs)):
        try:
            _ = generate_text(
                lm,
                prepare_inputs(lm, "Hello", [Image.new("RGB", (64, 64), color=(128, 128, 128))]),
                GenParams(max_new_tokens=8)
            )
            synchronize(lm.device)
        except Exception:
            break

    # Load images
    try:
        pil_images = [load_image(s) for s in images]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")

    # Timed steps
    synchronize(lm.device)
    inputs, t_pre = time_block(
        lambda: prepare_inputs(lm, prompt, pil_images),
        lambda: synchronize(lm.device)
    )
    text, t_gen = time_block(
        lambda: generate_text(lm, inputs, generation),
        lambda: synchronize(lm.device)
    )
    # for future use, useful for cleaning, parsing, transforming model output
    _, t_post = time_block(lambda: None, lambda: synchronize(lm.device))  # placeholder if you add detokenization etc.
    return text, t_pre, t_gen, t_post


def handle_minicpmv(lm: LoadedModel, image: Image.Image, prompt: str, gen: GenParams):
    """MiniCPM-V path: uses the model's built-in chat() API."""
    def generate_text_chat() -> str:
        # Prepare msgs in the format expected by model.chat
        msgs = [{"role": "user", "content": prompt}]
        # Call the model's built-in chat method
        response = lm.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=lm.tokenizer,
            sampling=gen.do_sample,
            temperature=gen.temperature,
            stream=False  # Set True if you want streaming later
        )
        return response

    # Run chat-based inference
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API
    return text, t_pre, t_gen, t_post


def handle_internlm_xcomposer(lm: LoadedModel, images_pil: List[Image.Image],
                              prompt: str, gen: GenParams):
    """InternLM-XComposer2 path: manual vision preprocessing + chat() API."""
    def generate_text_chat():
        # Preprocess every image with the model-supplied CLIP transform.
        imgs = [lm.model.vis_processor(img.convert("RGB")) for img in images_pil]
        batch = torch.stack(imgs).to(lm.device, dtype=lm.dtype)
        # Build the query string — one token per picture.
        # NOTE(review): this produces only whitespace before the prompt;
        # the per-image placeholder token (e.g. "<ImageHere>") appears to be
        # missing — confirm against the model card.
        query = (" " * len(images_pil)).strip() + " " + prompt
        # Run chat-style generation.
        with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
            response, _ = lm.model.chat(
                lm.tokenizer,
                query=query,
                image=batch,
                history=[],
                do_sample=gen.do_sample,
                temperature=gen.temperature,
                max_new_tokens=gen.max_new_tokens,
            )
        return response

    # Run chat-based inference
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API
    return text, t_pre, t_gen, t_post


def handle_qwen2vl(lm: LoadedModel, image_strings: List[str], prompt: str, gen: GenParams):
    """Qwen2-VL path: chat template + trim of the echoed prompt tokens.

    Only the first image is used — mirrors the original behavior.
    """
    images = [load_image(s) for s in image_strings]
    image = images[0]
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text_prompt = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = lm.processor(
        text=[text_prompt], images=[image], padding=True, return_tensors="pt"
    )
    inputs = inputs.to(lm.device)

    synchronize(lm.device)
    output_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device)
    )
    # Strip the prompt tokens that generate() echoes back at the front.
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = lm.processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # FIX: batch_decode returns a list; InferResponse.output_text is a str,
    # so the original list return failed response validation.
    return output_text[0], 0.0, t_gen, 0.0


def handle_gemma3(lm: LoadedModel, image_refs: List[str], prompt: str, gen: GenParams):
    """Gemma-3 path: chat template + timed generate(). Uses first image only."""
    img = load_image(image_refs[0])
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    text_prompt = lm.processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = lm.processor(
        img, text_prompt, add_special_tokens=False, return_tensors="pt"
    ).to(lm.device)

    # Run chat-based inference
    synchronize(lm.device)
    out_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API
    return lm.processor.decode(out_ids[0], skip_special_tokens=True), t_pre, t_gen, t_post


@app.post("/infer", response_model=InferResponse)
def infer(req: InferRequest):
    print("infer got")
    # Load / reuse model
    lm = resolve_model(req.model_path, req.dtype)
    print(f"{lm.model_type=}")

    if lm.model_type == 'minicpmv':
        text, t_pre, t_gen, t_post = handle_minicpmv(
            lm, load_image(req.images[0]), req.prompt, req.generation)
    elif lm.model_type in ("qwen2vl", "qwen2-vl", "qwen2_vl"):
        text, t_pre, t_gen, t_post = handle_qwen2vl(lm, req.images, req.prompt, req.generation)
    elif lm.model_type == "internlmxcomposer2":
        # Load images
        try:
            pil_images = [load_image(s) for s in req.images]
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
        text, t_pre, t_gen, t_post = handle_internlm_xcomposer(
            lm, pil_images, req.prompt, req.generation)
    # FIX: handle_gemma3 existed but was never dispatched, so gemma3 models
    # fell through to the generic path despite a dedicated loader/handler.
    elif lm.model_type in ("gemma3", "gemma-3", "gemma_3"):
        text, t_pre, t_gen, t_post = handle_gemma3(lm, req.images, req.prompt, req.generation)
    else:
        text, t_pre, t_gen, t_post = handle_normal_case(
            lm, req.warmup_runs, req.images, req.prompt, req.generation)

    timings = {
        "preprocess": t_pre,
        "generate": t_gen,
        "postprocess": t_post,
        "e2e": t_pre + t_gen + t_post
    }
    return InferResponse(
        output_text=text,
        timings_ms=timings,
        device={
            "type": lm.device.type,
            "name": device_name(lm.device),
            "total_memory_gb": device_total_mem_gb(lm.device)
        },
        model_info={
            "name": lm.model_path,
            "precision": str(lm.dtype).replace("torch.", ""),
            "framework": "transformers"
        }
    )


# Entry
# Run: uvicorn server:app --host 0.0.0.0 --port 8000