"""Unified multi-backend VLM inference server (FastAPI + Transformers).

Loads vision-language models on the best available accelerator (CUDA/ROCm,
Intel XPU, Apple MPS, Cambricon MLU, or CPU) and exposes HTTP routes for
loading, unloading, and timed inference.

Run: uvicorn server:app --host 0.0.0.0 --port 8000
"""

import base64
import gc
import io
import os
import time
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Tuple

import torch

try:
    # Importing torch_mlu registers the Cambricon MLU backend on torch.
    import torch_mlu  # noqa: F401
except ImportError:
    pass

from PIL import Image
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)

try:
    # Model-specific classes exist only in sufficiently new transformers.
    from transformers import (
        Qwen2VLForConditionalGeneration,
        Gemma3ForConditionalGeneration,
    )
except ImportError:
    pass

app = FastAPI(title="Unified VLM API (Transformers)")


# ----------------------------
# Device selection & sync
# ----------------------------
def best_device() -> torch.device:
    """Return the preferred available device, falling back to CPU."""
    # CUDA covers NVIDIA and AMD ROCm builds under torch.cuda
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Intel oneAPI/XPU (ipex)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    # Apple MPS
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    # Cambricon MLU
    if hasattr(torch, "mlu") and torch.mlu.is_available():
        return torch.device("mlu")
    return torch.device("cpu")


def device_name(device: torch.device) -> str:
    """Human-readable name for *device*."""
    if device.type == "cuda":
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "CUDA device"
    if device.type == "xpu":
        return "Intel XPU"
    if device.type == "mps":
        return "Apple MPS"
    if device.type == "mlu":
        return "cambricon MLU"
    return "CPU"


def device_total_mem_gb(device: torch.device) -> Optional[float]:
    """Total device memory in GiB, or None where reporting is unavailable."""
    try:
        if device.type == "cuda":
            return torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        # For others, memory reporting varies
        return None
    except Exception:
        return None


def synchronize(device: torch.device) -> None:
    """Block until queued device work finishes, for accurate wall times."""
    try:
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "xpu" and hasattr(torch, "xpu"):
            torch.xpu.synchronize()
        elif device.type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()
        elif device.type == "mlu" and hasattr(torch, "mlu"):
            # Fix: MLU was selectable in best_device() but never synchronized,
            # which skewed reported timings on that backend.
            torch.mlu.synchronize()
    except Exception:
        # Best-effort: a failed sync should never abort inference.
        pass


# ----------------------------
# Model registry
# ----------------------------
class LoadedModel:
    """Bundle of everything needed to run inference with one loaded model."""

    def __init__(self, model_type: str, model_path: str, model, processor,
                 tokenizer, device: torch.device, dtype: torch.dtype):
        self.model_type = model_type
        self.model_path = model_path
        self.model = model
        self.processor = processor  # may be None for tokenizer-only models
        self.tokenizer = tokenizer  # may be None when the processor wraps one
        self.device = device
        self.dtype = dtype


# Cache of loaded models keyed by path so repeat requests reuse them.
_loaded: Dict[str, LoadedModel] = {}


# ----------------------------
# IO helpers
# ----------------------------
def load_image(ref: str) -> Image.Image:
    """Load an RGB image from an http(s) URL, local path, or base64 payload.

    Raises ValueError when *ref* matches none of the supported forms.
    """
    if ref.startswith("http://") or ref.startswith("https://"):
        # Rely on HF file utilities only if you want offline; here use requests lazily
        import requests
        r = requests.get(ref, timeout=30)
        r.raise_for_status()
        return Image.open(io.BytesIO(r.content)).convert("RGB")
    if os.path.exists(ref):
        return Image.open(ref).convert("RGB")
    # Base64 data URL
    if ref.startswith("data:image"):
        header, b64 = ref.split(",", 1)
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    # Raw base64
    try:
        return Image.open(io.BytesIO(base64.b64decode(ref))).convert("RGB")
    except Exception:
        raise ValueError(f"Unsupported image reference: {ref[:80]}...")


def pick_dtype(req_dtype: str, device: torch.device) -> torch.dtype:
    """Map a requested dtype string ("auto"/"float16"/...) to a torch dtype."""
    if req_dtype == "float16":
        return torch.float16
    if req_dtype == "bfloat16":
        return torch.bfloat16
    if req_dtype == "float32":
        return torch.float32
    # auto
    if device.type in ("cuda", "xpu"):
        # bfloat16 works broadly on modern GPUs; fall back to float16 for older CUDA
        try:
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        except Exception:
            return torch.float16
    if device.type in ("mps", "mlu"):
        return torch.float16
    return torch.float32


def autocast_ctx(device: torch.device, dtype: torch.dtype):
    """Return an autocast context for *device*, or a no-op context manager."""
    if device.type in ("cpu", "cuda", "xpu", "mps"):
        return torch.autocast(device_type=device.type, dtype=dtype)
    # fallback no-op (e.g. mlu)
    return nullcontext()


# ----------------------------
# Requests/Responses
# ----------------------------
class GenParams(BaseModel):
    max_new_tokens: int = 128
    temperature: float = 0.0
    top_p: float = 1.0
    do_sample: bool = False


class InferRequest(BaseModel):
    model_path: str
    prompt: str
    images: List[str]
    generation: GenParams = GenParams()
    dtype: str = "auto"  # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1
    measure_token_times: bool = False


class InferResponse(BaseModel):
    output_text: str
    timings_ms: Dict[str, float]
    device: Dict[str, Any]
    model_info: Dict[str, Any]


class LoadModelRequest(BaseModel):
    model_path: str
    dtype: str = "auto"


class UnloadModelRequest(BaseModel):
    model_path: str


# ----------------------------
# Model loading
# ----------------------------
def resolve_model(model_path: str, dtype_str: str) -> LoadedModel:
    """Load the model at *model_path* (or return the cached instance).

    Dispatches on the config's ``model_type`` for architectures that need a
    dedicated class, then falls back to trying the generic Auto classes.
    """
    if model_path in _loaded:
        return _loaded[model_path]

    dev = best_device()
    dt = pick_dtype(dtype_str, dev)
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    print(f"model type detected: {model_type}, device: {dev}, dt: {dt}")

    if model_type in ("qwen2_vl", "qwen2-vl"):
        print(f"Loading Qwen2-VL using Qwen2VLForConditionalGeneration")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,
            trust_remote_code=True,
        )
        print("Loaded model class:", type(model))
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm

    elif model_type == "internlmxcomposer2":
        # Fix: the original used `model_type in ("internlmxcomposer2")` — the
        # missing comma makes that substring membership in a *string*, which
        # would also match e.g. "composer2". Equality is what was intended.
        dt = torch.float16  # this family is only loaded in fp16 here
        print(f"dt change to {dt}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=dt, trust_remote_code=True, device_map='auto')
        model = model.eval()
        lm = LoadedModel(model_type, model_path, model, None, tokenizer, dev, dt)
        _loaded[model_path] = lm
        return lm

    elif model_type in ("gemma3", "gemma-3", "gemma_3"):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # we move to device below
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev).eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm

    # Generic fallback: try a sequence of Auto classes until one loads.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # Tokenizer is often part of Processor; still try to load explicitly for safety
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, use_fast=True)
    except Exception:
        tokenizer = None

    model = None
    errors = []
    for candidate in (AutoModel, AutoModelForVision2Seq, AutoModelForCausalLM):
        try:
            model = candidate.from_pretrained(
                model_path,
                torch_dtype=dt if dt != torch.float32 else None,
                device_map=None,  # we move to device manually
                trust_remote_code=True
            )
            break
        except Exception as e:
            errors.append(str(e))
    if model is None:
        raise RuntimeError(f"Unable to load model {model_path}. Errors: {errors}")

    model.to(dev)
    model.eval()
    lm = LoadedModel(model_type, model_path, model, processor, tokenizer, dev, dt)
    _loaded[model_path] = lm
    return lm


def unload_model(model_path: str) -> None:
    """Drop a cached model and free accelerator memory (best-effort)."""
    if model_path in _loaded:
        lm = _loaded.pop(model_path)
        try:
            del lm.model
            del lm.processor
            if lm.tokenizer:
                del lm.tokenizer
            gc.collect()
            if lm.device.type == "cuda":
                torch.cuda.empty_cache()
        except Exception:
            pass


# ----------------------------
# Core inference
# ----------------------------
def prepare_inputs(lm: LoadedModel, prompt: str, images: List[Image.Image]) -> Dict[str, Any]:
    """Encode prompt + images with the model's processor, moved to its device."""
    proc = lm.processor
    # If the processor exposes a chat template, use it (covers mllama, qwen2vl, MiniCPM-V2, ...)
    if hasattr(proc, "apply_chat_template"):
        conversation = [
            {"role": "user", "content": [
                {"type": "image"},  # one placeholder per image; repeat if >1
                {"type": "text", "text": prompt},
            ]}
        ]
        text_prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)
        encoded = proc(
            text=[text_prompt],  # list is required by some processors
            images=images,       # list-of-PIL
            padding=True,
            return_tensors="pt",
        )
    else:
        # generic fallback
        encoded = proc(text=prompt, images=images, return_tensors="pt")
    # Move to the target device
    return {k: v.to(lm.device) if torch.is_tensor(v) else v for k, v in encoded.items()}


def generate_text(lm: LoadedModel, inputs: Dict[str, Any], gen: GenParams) -> str:
    """Run model.generate on *inputs* and decode the first sequence."""
    gen_kwargs = dict(
        max_new_tokens=gen.max_new_tokens,
        temperature=gen.temperature,
        top_p=gen.top_p,
        do_sample=gen.do_sample
    )
    with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
        out_ids = lm.model.generate(**inputs, **gen_kwargs)
    # Decode with whichever tokenizer is available
    if lm.tokenizer:
        return lm.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Some processors expose a tokenizer inside
    if hasattr(lm.processor, "tokenizer") and lm.processor.tokenizer is not None:
        return lm.processor.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Last resort
    try:
        return lm.model.decode(out_ids[0])
    except Exception:
        return ""


def time_block(fn, sync, *args, **kwargs) -> Tuple[Any, float]:
    """Run ``fn(*args, **kwargs)``, sync the device, return (result, elapsed_ms)."""
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    sync()
    dur_ms = (time.perf_counter() - start) * 1000.0
    return out, dur_ms


# ----------------------------
# Routes
# ----------------------------
@app.get("/health")
def health():
    """Liveness probe with a summary of available backends."""
    dev = best_device()
    return {
        "status": "ok",
        "device": str(dev),
        "device_name": device_name(dev),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        "xpu_available": hasattr(torch, "xpu") and torch.xpu.is_available(),
    }


@app.get("/info")
def info():
    """Report the selected device and library versions."""
    dev = best_device()
    return {
        "device": {
            "type": dev.type,
            "name": device_name(dev),
            "total_memory_gb": device_total_mem_gb(dev)
        },
        "torch": torch.__version__,
        "transformers": __import__("transformers").__version__
    }


@app.post("/load_model")
def load_model(req: LoadModelRequest):
    """Eagerly load a model so later /infer calls skip the load cost."""
    lm = resolve_model(req.model_path, req.dtype)
    print(f"model with path {req.model_path} loaded!")
    return {
        "loaded": lm.model_path,
        "device": str(lm.device),
        "dtype": str(lm.dtype)
    }


@app.get("/model_status")
def model_status(model_path: str = Query(..., description="Path of the model to check")):
    """Report whether *model_path* is currently cached."""
    lm = _loaded.get(model_path)
    if lm:
        return {
            "loaded": True,
            "model_path": lm.model_path,
            "device": str(lm.device),
            "dtype": str(lm.dtype)
        }
    return {
        "loaded": False,
        "model_path": model_path
    }


@app.post("/unload_model")
def unload(req: UnloadModelRequest):
    """Remove a model from the cache and free its memory."""
    unload_model(req.model_path)
    # Fix: was `req.model`, but UnloadModelRequest only defines `model_path`,
    # so every unload call raised AttributeError while building the response.
    return {"unloaded": req.model_path}


def handle_normal_case(lm: LoadedModel, warmup_runs: int, images: List[str],
                       prompt: str, generation: GenParams):
    """Generic processor-based inference path with warmup and timed stages."""
    # Warmup with a tiny dummy image; best-effort, abort silently on failure.
    for _ in range(max(0, warmup_runs)):
        try:
            _ = generate_text(
                lm,
                prepare_inputs(lm, "Hello", [Image.new("RGB", (64, 64), color=(128, 128, 128))]),
                GenParams(max_new_tokens=8)
            )
            synchronize(lm.device)
        except Exception:
            break

    # Load images
    try:
        pil_images = [load_image(s) for s in images]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")

    # Timed steps
    synchronize(lm.device)
    # Fix: was `_, t_pre = time_block(...)` followed by `inputs = _`; bind the
    # result directly.
    inputs, t_pre = time_block(lambda: prepare_inputs(lm, prompt, pil_images),
                               lambda: synchronize(lm.device))
    text, t_gen = time_block(lambda: generate_text(lm, inputs, generation),
                             lambda: synchronize(lm.device))
    # for future use, useful for cleaning, parsing, transforming model output
    _, t_post = time_block(lambda: None, lambda: synchronize(lm.device))  # placeholder if you add detokenization etc.
    return text, t_pre, t_gen, t_post


def handle_minicpmv(lm: LoadedModel, image: Image.Image, prompt: str, gen: GenParams):
    """MiniCPM-V path using the model's built-in chat() API."""
    def generate_text_chat() -> str:
        # Prepare msgs in the format expected by model.chat
        msgs = [{"role": "user", "content": prompt}]
        # Call the model's built-in chat method
        response = lm.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=lm.tokenizer,
            sampling=gen.do_sample,
            temperature=gen.temperature,
            stream=False  # Set True if you want streaming later
        )
        return response

    # Run chat-based inference
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API
    return text, t_pre, t_gen, t_post


def handle_internlm_xcomposer(lm: LoadedModel, images_pil: List[Image.Image],
                              prompt: str, gen: GenParams):
    """InternLM-XComposer2 path using the model's chat() API."""
    def generate_text_chat():
        # Preprocess every image with the model-supplied CLIP transform
        imgs = [lm.model.vis_processor(img.convert("RGB")) for img in images_pil]
        batch = torch.stack(imgs).to(lm.device, dtype=lm.dtype)
        # Build the query string - one token per picture.
        # NOTE(review): as written this concatenates stripped spaces, so the
        # query contains no per-image marker; the placeholder token (e.g.
        # "<ImageHere>") appears to have been lost — confirm against the
        # model's chat API before relying on multi-image input.
        query = (" " * len(images_pil)).strip() + " " + prompt
        # Run chat-style generation
        with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
            response, _ = lm.model.chat(
                lm.tokenizer,
                query=query,
                image=batch,
                history=[],
                do_sample=gen.do_sample,
                temperature=gen.temperature,
                max_new_tokens=gen.max_new_tokens,
            )
        return response

    # Run chat-based inference
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API
    return text, t_pre, t_gen, t_post


def handle_qwen2vl(lm: LoadedModel, image_strings: List[str], prompt: str, gen: GenParams):
    """Qwen2-VL path: chat template + prompt-token stripping before decode."""
    images = [load_image(s) for s in image_strings]
    image = images[0]  # only the first image is used by this handler
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text_prompt = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = lm.processor(
        text=[text_prompt], images=[image], padding=True, return_tensors="pt"
    )
    inputs = inputs.to(lm.device)

    output_ids = lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens)
    # Strip the echoed prompt tokens so only newly generated ids are decoded.
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    output_texts = lm.processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # Fix: batch_decode returns a list, but InferResponse.output_text is a str;
    # returning the list made the response-model validation fail.
    output_text = output_texts[0] if output_texts else ""
    return output_text, 0.0, 0.0, 0.0  # dummy timing values


def handle_gemma3(lm: LoadedModel, image_refs: List[str], prompt: str, gen: GenParams):
    """Gemma-3 path: chat template, timed generate, full-sequence decode."""
    img = load_image(image_refs[0])
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    text_prompt = lm.processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = lm.processor(
        img, text_prompt, add_special_tokens=False, return_tensors="pt"
    ).to(lm.device)

    # Run timed generation
    synchronize(lm.device)
    out_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not measured separately on this path
    return lm.processor.decode(out_ids[0], skip_special_tokens=True), t_pre, t_gen, t_post


@app.post("/infer", response_model=InferResponse)
def infer(req: InferRequest):
    """Run one inference request, dispatching on the loaded model's type."""
    print("infer got")
    # Load / reuse model
    lm = resolve_model(req.model_path, req.dtype)
    print(f"{lm.model_type=}")

    if lm.model_type == 'minicpmv':
        text, t_pre, t_gen, t_post = handle_minicpmv(
            lm, load_image(req.images[0]), req.prompt, req.generation)
    elif lm.model_type in ("qwen2vl", "qwen2-vl", "qwen2_vl"):
        text, t_pre, t_gen, t_post = handle_qwen2vl(
            lm, req.images, req.prompt, req.generation)
    elif lm.model_type == "internlmxcomposer2":
        # Load images
        try:
            pil_images = [load_image(s) for s in req.images]
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
        text, t_pre, t_gen, t_post = handle_internlm_xcomposer(
            lm, pil_images, req.prompt, req.generation)
    elif lm.model_type in ("gemma3", "gemma-3", "gemma_3"):
        # Fix: handle_gemma3 existed (and resolve_model has a gemma3 load path)
        # but was never dispatched; gemma-3 previously fell through to the
        # generic handler.
        text, t_pre, t_gen, t_post = handle_gemma3(
            lm, req.images, req.prompt, req.generation)
    else:
        text, t_pre, t_gen, t_post = handle_normal_case(
            lm, req.warmup_runs, req.images, req.prompt, req.generation)

    timings = {
        "preprocess": t_pre,
        "generate": t_gen,
        "postprocess": t_post,
        "e2e": t_pre + t_gen + t_post
    }
    return InferResponse(
        output_text=text,
        timings_ms=timings,
        device={
            "type": lm.device.type,
            "name": device_name(lm.device),
            "total_memory_gb": device_total_mem_gb(lm.device)
        },
        model_info={
            "name": lm.model_path,
            "precision": str(lm.dtype).replace("torch.", ""),
            "framework": "transformers"
        }
    )


# Entry
# Run: uvicorn server:app --host 0.0.0.0 --port 8000