"""Unified VLM inference server (FastAPI + Hugging Face Transformers).

Exposes /load_model, /unload_model, /infer, /health and /info endpoints and
dispatches to model-family-specific handlers (MiniCPM-V, Qwen2-VL,
InternLM-XComposer2, Gemma-3, plus a generic processor/generate fallback).

Run with: uvicorn server:app --host 0.0.0.0 --port 8000
"""

import base64
import gc
import io
import os
import time
from contextlib import nullcontext
from typing import List, Optional, Dict, Any, Tuple

import torch
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoModel,
    Qwen2VLForConditionalGeneration,
    Gemma3ForConditionalGeneration,
)

app = FastAPI(title="Unified VLM API (Transformers)")


# ----------------------------
# Device selection & sync
# ----------------------------
def best_device() -> torch.device:
    """Pick the best available accelerator, preferring CUDA > XPU > MPS > CPU."""
    # torch.cuda covers both NVIDIA CUDA and AMD ROCm builds.
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Intel oneAPI/XPU (requires an ipex-enabled torch build).
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    # Apple Metal Performance Shaders.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def device_name(device: torch.device) -> str:
    """Return a human-readable name for *device* (best effort)."""
    if device.type == "cuda":
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "CUDA device"
    if device.type == "xpu":
        return "Intel XPU"
    if device.type == "mps":
        return "Apple MPS"
    return "CPU"


def device_total_mem_gb(device: torch.device) -> Optional[float]:
    """Total device memory in GiB, or None when the backend can't report it."""
    try:
        if device.type == "cuda":
            return torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        # Memory reporting APIs vary across non-CUDA backends.
        return None
    except Exception:
        return None


def synchronize(device: torch.device) -> None:
    """Block until all queued device work is done, so wall-clock timings are accurate."""
    try:
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "xpu" and hasattr(torch, "xpu"):
            torch.xpu.synchronize()
        elif device.type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()
    except Exception:
        # Best effort: a failed sync only makes timings less precise.
        pass


# ----------------------------
# Model registry
# ----------------------------
class LoadedModel:
    """A loaded model plus its processor/tokenizer and placement metadata."""

    def __init__(self, model_type: str, model_path: str, model, processor, tokenizer,
                 device: torch.device, dtype: torch.dtype):
        self.model_type = model_type      # transformers config.model_type string
        self.model_path = model_path      # HF hub id or local path used to load
        self.model = model
        self.processor = processor        # may be None for tokenizer-only models
        self.tokenizer = tokenizer        # may be None when processor bundles one
        self.device = device
        self.dtype = dtype


# Cache of already-loaded models keyed by model_path.
_loaded: Dict[str, LoadedModel] = {}


# ----------------------------
# IO helpers
# ----------------------------
def load_image(ref: str) -> Image.Image:
    """Load an RGB PIL image from an http(s) URL, a local path, or base64 data.

    Accepts ``data:image/...;base64,`` URLs as well as raw base64 payloads.
    Raises ValueError when *ref* cannot be interpreted as an image.
    """
    if ref.startswith("http://") or ref.startswith("https://"):
        # Imported lazily so offline deployments without requests still start.
        import requests
        r = requests.get(ref, timeout=30)
        r.raise_for_status()
        return Image.open(io.BytesIO(r.content)).convert("RGB")
    if os.path.exists(ref):
        return Image.open(ref).convert("RGB")
    if ref.startswith("data:image"):
        header, b64 = ref.split(",", 1)
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    # Last resort: treat the whole string as raw base64.
    try:
        return Image.open(io.BytesIO(base64.b64decode(ref))).convert("RGB")
    except Exception as exc:
        raise ValueError(f"Unsupported image reference: {ref[:80]}...") from exc


def pick_dtype(req_dtype: str, device: torch.device) -> torch.dtype:
    """Resolve a requested dtype string ("auto"/"float16"/"bfloat16"/"float32")."""
    if req_dtype == "float16":
        return torch.float16
    if req_dtype == "bfloat16":
        return torch.bfloat16
    if req_dtype == "float32":
        return torch.float32
    # "auto": prefer bf16 on modern GPUs, fp16 otherwise.
    if device.type in ("cuda", "xpu"):
        # NOTE(review): is_bf16_supported() is a CUDA API; on XPU the except
        # branch makes this fall back to fp16 — confirm that is intended.
        try:
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        except Exception:
            return torch.float16
    if device.type == "mps":
        return torch.float16
    return torch.float32


def autocast_ctx(device: torch.device, dtype: torch.dtype):
    """Return an autocast context for *device*, or a no-op context as fallback."""
    # NOTE(review): mps autocast support depends on the torch version — confirm.
    if device.type in ("cpu", "cuda", "xpu", "mps"):
        return torch.autocast(device_type=device.type, dtype=dtype)
    return nullcontext()


# ----------------------------
# Requests/Responses
# ----------------------------
class GenParams(BaseModel):
    """Generation hyper-parameters forwarded to model.generate()/chat()."""
    max_new_tokens: int = 128
    temperature: float = 0.0
    top_p: float = 1.0
    do_sample: bool = False


class InferRequest(BaseModel):
    model_path: str
    prompt: str
    images: List[str]                     # URLs, paths, or base64 payloads
    generation: GenParams = GenParams()
    dtype: str = "auto"                   # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1
    measure_token_times: bool = False
    

class InferResponse(BaseModel):
    output_text: str
    timings_ms: Dict[str, float]
    device: Dict[str, Any]
    model_info: Dict[str, Any]


class LoadModelRequest(BaseModel):
    model_path: str
    dtype: str = "auto"


class UnloadModelRequest(BaseModel):
    model_path: str


# ----------------------------
# Model loading
# ----------------------------
def resolve_model(model_path: str, dtype_str: str) -> LoadedModel:
    """Load (or fetch from cache) the model at *model_path*.

    Dispatches on config.model_type to the correct transformers class and
    falls back to AutoModel/AutoModelForVision2Seq/AutoModelForCausalLM.
    """
    if model_path in _loaded:
        return _loaded[model_path]

    dev = best_device()
    dt = pick_dtype(dtype_str, dev)

    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    print(f"model type detected: {model_type}, device: {dev}, dt: {dt}")

    if model_type in ("qwen2_vl", "qwen2-vl"):
        print(f"Loading Qwen2-VL using Qwen2VLForConditionalGeneration")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            # torch_dtype=None lets transformers keep the checkpoint default fp32.
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,
            trust_remote_code=True,
        )
        print("Loaded model class:", type(model))
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm
    # BUG FIX: the original used `model_type in ("internlmxcomposer2")`, which is
    # a substring test against a string, not tuple membership.
    elif model_type == "internlmxcomposer2":
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dt, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, None, tokenizer, dev, dt)
        _loaded[model_path] = lm
        return lm
    elif model_type in ("gemma3", "gemma-3", "gemma_3"):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # we move to device below
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        model.to(dev).eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm

    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # Tokenizer is often part of Processor; still try to load explicitly for safety.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    except Exception:
        tokenizer = None

    model = None
    errors = []
    for candidate in (AutoModel, AutoModelForVision2Seq, AutoModelForCausalLM):
        try:
            model = candidate.from_pretrained(
                model_path,
                torch_dtype=dt if dt != torch.float32 else None,
                device_map=None,  # we move to device manually
                trust_remote_code=True
            )
            break
        except Exception as e:
            errors.append(str(e))
    if model is None:
        raise RuntimeError(f"Unable to load model {model_path}. Errors: {errors}")

    model.to(dev)
    model.eval()

    lm = LoadedModel(model_type, model_path, model, processor, tokenizer, dev, dt)
    _loaded[model_path] = lm
    return lm


def unload_model(model_path: str) -> None:
    """Drop a cached model and release accelerator memory (best effort)."""
    if model_path in _loaded:
        lm = _loaded.pop(model_path)
        try:
            del lm.model
            del lm.processor
            if lm.tokenizer:
                del lm.tokenizer
            gc.collect()
            if lm.device.type == "cuda":
                torch.cuda.empty_cache()
        except Exception:
            pass


# ----------------------------
# Core inference
# ----------------------------
def prepare_inputs(lm: LoadedModel, prompt: str, images: List[Image.Image]) -> Dict[str, Any]:
    """Tokenize *prompt* + *images* with the model's processor and move to device."""
    proc = lm.processor

    # If the processor exposes a chat template, use it (covers mllama, qwen2vl,
    # MiniCPM-V2, ...).
    if hasattr(proc, "apply_chat_template"):
        # BUG FIX: emit one image placeholder per image instead of always one.
        content: List[Dict[str, Any]] = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        conversation = [{"role": "user", "content": content}]
        text_prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)

        encoded = proc(
            text=[text_prompt],   # list is required by some processors
            images=images,        # list-of-PIL
            padding=True,
            return_tensors="pt",
        )
    else:
        # Generic fallback for processors without a chat template.
        encoded = proc(text=prompt, images=images, return_tensors="pt")

    # Move every tensor to the target device, pass non-tensors through unchanged.
    return {k: v.to(lm.device) if torch.is_tensor(v) else v for k, v in encoded.items()}


def generate_text(lm: LoadedModel, inputs: Dict[str, Any], gen: GenParams) -> str:
    """Run generation on prepared *inputs* and decode the first sequence."""
    gen_kwargs = dict(
        max_new_tokens=gen.max_new_tokens,
        temperature=gen.temperature,
        top_p=gen.top_p,
        do_sample=gen.do_sample
    )

    with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
        out_ids = lm.model.generate(**inputs, **gen_kwargs)
    # Decode with whichever tokenizer is available.
    if lm.tokenizer:
        return lm.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Some processors expose a tokenizer inside.
    if hasattr(lm.processor, "tokenizer") and lm.processor.tokenizer is not None:
        return lm.processor.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Last resort: a few remote-code models expose model.decode directly.
    try:
        return lm.model.decode(out_ids[0])
    except Exception:
        return ""


def time_block(fn, sync, *args, **kwargs) -> Tuple[Any, float]:
    """Run *fn*, call *sync* to flush device work, and return (result, millis)."""
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    sync()
    dur_ms = (time.perf_counter() - start) * 1000.0
    return out, dur_ms


# ----------------------------
# Routes
# ----------------------------
@app.get("/health")
def health():
    """Liveness probe with backend availability flags."""
    dev = best_device()
    return {
        "status": "ok",
        "device": str(dev),
        "device_name": device_name(dev),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        "xpu_available": hasattr(torch, "xpu") and torch.xpu.is_available(),
    }


@app.get("/info")
def info():
    """Report device and library version details."""
    dev = best_device()
    return {
        "device": {
            "type": dev.type,
            "name": device_name(dev),
            "total_memory_gb": device_total_mem_gb(dev)
        },
        "torch": torch.__version__,
        "transformers": __import__("transformers").__version__
    }


@app.post("/load_model")
def load_model(req: LoadModelRequest):
    """Eagerly load a model so the first /infer call is fast."""
    lm = resolve_model(req.model_path, req.dtype)
    return {
        "loaded": lm.model_path,
        "device": str(lm.device),
        "dtype": str(lm.dtype)
    }


@app.post("/unload_model")
def unload(req: UnloadModelRequest):
    """Unload a cached model and free its memory."""
    unload_model(req.model_path)
    # BUG FIX: the original returned req.model, but the request model only
    # declares model_path — that raised AttributeError on every call.
    return {"unloaded": req.model_path}


def handle_normal_case(lm: LoadedModel, warmup_runs: int, images: List[str], prompt: str, generation: GenParams):
    """Generic processor/generate path with warmup and per-stage timings.

    Returns (text, preprocess_ms, generate_ms, postprocess_ms).
    """
    # Warmup with a tiny synthetic image so timing below excludes lazy init.
    for _ in range(max(0, warmup_runs)):
        try:
            _ = generate_text(
                lm,
                prepare_inputs(lm, "Hello", [Image.new("RGB", (64, 64), color=(128, 128, 128))]),
                GenParams(max_new_tokens=8)
            )
            synchronize(lm.device)
        except Exception:
            break

    # Load images (bad references surface as HTTP 400, not 500).
    try:
        pil_images = [load_image(s) for s in images]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")

    # Timed steps
    synchronize(lm.device)
    inputs, t_pre = time_block(lambda: prepare_inputs(lm, prompt, pil_images), lambda: synchronize(lm.device))

    text, t_gen = time_block(lambda: generate_text(lm, inputs, generation), lambda: synchronize(lm.device))
    # Placeholder for future postprocessing (cleaning/parsing model output).
    _, t_post = time_block(lambda: None, lambda: synchronize(lm.device))
    return text, t_pre, t_gen, t_post


def handle_minicpmv(lm: LoadedModel, image: Image.Image, prompt: str, gen: GenParams):
    """MiniCPM-V path: uses the model's built-in chat() API."""
    def generate_text_chat() -> str:
        # Prepare msgs in the format expected by model.chat
        msgs = [{"role": "user", "content": prompt}]

        # Call the model's built-in chat method
        response = lm.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=lm.tokenizer,
            sampling=gen.do_sample,
            temperature=gen.temperature,
            stream=False  # Set True if you want streaming later
        )
        return response

    # Run chat-based inference
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # chat() does its own preprocessing; not measured

    return text, t_pre, t_gen, t_post


def handle_internlm_xcomposer(lm: LoadedModel,
                              images_pil: List[Image.Image],
                              prompt: str,
                              gen: GenParams):
    """InternLM-XComposer2 path: manual vision preprocessing + model.chat()."""
    def generate_text_chat():
        # 1. Preprocess every image with the model-supplied CLIP transform.
        imgs = [lm.model.vis_processor(img.convert("RGB")) for img in images_pil]
        batch = torch.stack(imgs).to(lm.device, dtype=lm.dtype)

        # 2. Build the query string - one token per picture.
        query = (" " * len(images_pil)).strip() + " " + prompt

        # 3. Run chat-style generation.
        with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
            response, _ = lm.model.chat(
                lm.tokenizer,
                query=query,
                image=batch,
                history=[],
                do_sample=gen.do_sample,
                temperature=gen.temperature,
                max_new_tokens=gen.max_new_tokens,
            )
        return response

    # Run chat-based inference
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # chat() does its own preprocessing; not measured

    return text, t_pre, t_gen, t_post


def handle_qwen2vl(lm: LoadedModel, image_strings: List[str], prompt: str, gen: GenParams):
    """Qwen2-VL path: chat template + generate, trimming the prompt tokens."""
    images = [load_image(s) for s in image_strings]
    # NOTE(review): only the first image is used — confirm multi-image support.
    image = images[0]

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text_prompt = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = lm.processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(lm.device)

    output_ids = lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens)

    # Strip the echoed prompt tokens from each generated sequence.
    generated_ids = [
        out_seq[len(in_seq):]
        for in_seq, out_seq in zip(inputs.input_ids, output_ids)
    ]
    decoded = lm.processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # BUG FIX: batch_decode returns a list; InferResponse.output_text is a str.
    return decoded[0] if decoded else "", 0.0, 0.0, 0.0  # dummy timing values


def handle_gemma3(lm: LoadedModel, image_refs: List[str], prompt: str, gen: GenParams):
    """Gemma-3 path: chat template + generate, timing only the generation."""
    img = load_image(image_refs[0])

    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    text_prompt = lm.processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = lm.processor(
        img,
        text_prompt,
        add_special_tokens=False,
        return_tensors="pt"
    ).to(lm.device)

    # Run timed generation.
    synchronize(lm.device)
    out_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # preprocessing not separately timed on this path

    # NOTE(review): decoding out_ids[0] includes the prompt tokens — confirm
    # whether the echoed prompt should be trimmed as in handle_qwen2vl.
    return lm.processor.decode(out_ids[0], skip_special_tokens=True), t_pre, t_gen, t_post


@app.post("/infer", response_model=InferResponse)
def infer(req: InferRequest):
    """Run one image+text inference, dispatching on the model family."""
    print("infer got")
    # Load / reuse model
    lm = resolve_model(req.model_path, req.dtype)
    print(f"{lm.model_type=}")
    if lm.model_type == 'minicpmv':
        text, t_pre, t_gen, t_post = handle_minicpmv(lm, load_image(req.images[0]), req.prompt, req.generation)
    elif lm.model_type in ("qwen2vl", "qwen2-vl", "qwen2_vl"):
        text, t_pre, t_gen, t_post = handle_qwen2vl(lm, req.images, req.prompt, req.generation)
    elif lm.model_type == "internlmxcomposer2":
        # Load images (bad references surface as HTTP 400, not 500).
        try:
            pil_images = [load_image(s) for s in req.images]
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
        text, t_pre, t_gen, t_post = handle_internlm_xcomposer(lm, pil_images, req.prompt, req.generation)
    else:
        text, t_pre, t_gen, t_post = handle_normal_case(lm, req.warmup_runs, req.images, req.prompt, req.generation)
    timings = {
        "preprocess": t_pre,
        "generate": t_gen,
        "postprocess": t_post,
        "e2e": t_pre + t_gen + t_post
    }

    return InferResponse(
        output_text=text,
        timings_ms=timings,
        device={
            "type": lm.device.type,
            "name": device_name(lm.device),
            "total_memory_gb": device_total_mem_gb(lm.device)
        },
        model_info={
            "name": lm.model_path,
            "precision": str(lm.dtype).replace("torch.", ""),
            "framework": "transformers"
        }
    )

# Entry
# Run: uvicorn server:app --host 0.0.0.0 --port 8000