605 lines
19 KiB
Python
605 lines
19 KiB
Python
import base64
|
||
import gc
|
||
import io
|
||
import os
|
||
import time
|
||
from typing import List, Optional, Dict, Any, Tuple
|
||
|
||
import torch
|
||
from PIL import Image
|
||
from fastapi import FastAPI, HTTPException
|
||
from pydantic import BaseModel
|
||
from transformers import (
|
||
AutoProcessor,
|
||
AutoTokenizer,
|
||
AutoConfig,
|
||
AutoModelForCausalLM,
|
||
AutoModelForVision2Seq, AutoModel
|
||
)
|
||
|
||
|
||
try:
|
||
from transformers import (Qwen2VLForConditionalGeneration, Gemma3ForConditionalGeneration)
|
||
except ImportError:
|
||
pass
|
||
|
||
|
||
# Single FastAPI application hosting all VLM endpoints.
app = FastAPI(title="Unified VLM API (Transformers)")
|
||
|
||
|
||
# ----------------------------
|
||
# Device selection & sync
|
||
# ----------------------------
|
||
def best_device() -> torch.device:
    """Pick the fastest available accelerator, falling back to CPU.

    Preference order: CUDA (covers NVIDIA and AMD ROCm builds of torch),
    Intel XPU (oneAPI/ipex), Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
        return torch.device("xpu")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
||
|
||
|
||
def device_name(device: torch.device) -> str:
    """Return a human-readable label for *device*.

    CUDA devices are queried for their real product name; everything
    else gets a fixed label, with "CPU" as the catch-all.
    """
    if device.type == "cuda":
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "CUDA device"
    fixed_labels = {"xpu": "Intel XPU", "mps": "Apple MPS"}
    return fixed_labels.get(device.type, "CPU")
|
||
|
||
|
||
def device_total_mem_gb(device: torch.device) -> Optional[float]:
    """Total device memory in GiB, or None when it cannot be queried.

    Only CUDA exposes a stable total-memory API via torch; for other
    backends memory reporting varies, so None is returned.
    """
    if device.type != "cuda":
        return None
    try:
        props = torch.cuda.get_device_properties(0)
    except Exception:
        return None
    return props.total_memory / (1024 ** 3)
|
||
|
||
|
||
def synchronize(device: torch.device):
    """Block until pending work on *device* finishes (best effort).

    Needed for accurate wall-clock timings of asynchronous accelerator
    kernels. Failures are swallowed because this is purely a timing aid.
    """
    waiters = {
        "cuda": lambda: torch.cuda.synchronize(),
        "xpu": lambda: torch.xpu.synchronize() if hasattr(torch, "xpu") else None,
        "mps": lambda: torch.mps.synchronize() if hasattr(torch, "mps") else None,
    }
    waiter = waiters.get(device.type)
    if waiter is None:
        return
    try:
        waiter()
    except Exception:
        pass
|
||
|
||
|
||
# ----------------------------
|
||
# Model registry
|
||
# ----------------------------
|
||
class LoadedModel:
    """Bundle of a loaded model plus everything needed to run it.

    Holds the model weights together with the matching processor and/or
    tokenizer (either may be None depending on the model family), and the
    device/dtype the model was placed on, so request handlers stay stateless.
    """

    def __init__(self, model_type: str, model_path: str, model, processor, tokenizer,
                 device: torch.device, dtype: torch.dtype):
        # Checkpoint identity (config.model_type and the path it was loaded from).
        self.model_type, self.model_path = model_type, model_path
        # Runtime objects; processor or tokenizer may be None for some models.
        self.model, self.processor, self.tokenizer = model, processor, tokenizer
        # Placement chosen at load time.
        self.device, self.dtype = device, dtype
|
||
|
||
|
||
# Cache of models already loaded in this process, keyed by model_path.
_loaded: Dict[str, LoadedModel] = {}
|
||
|
||
|
||
# ----------------------------
|
||
# IO helpers
|
||
# ----------------------------
|
||
def load_image(ref: str) -> Image.Image:
    """Load an RGB PIL image from a flexible reference.

    Accepts, in order of precedence:
      * http(s) URLs (fetched with a 30 s timeout),
      * local filesystem paths,
      * ``data:image/...;base64,<payload>`` data URLs,
      * raw base64-encoded image bytes.

    Raises:
        ValueError: when *ref* matches none of the supported forms or the
            payload cannot be decoded into an image.
    """
    if ref.startswith(("http://", "https://")):
        # Imported lazily so offline deployments without `requests` still start.
        import requests
        resp = requests.get(ref, timeout=30)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")
    if os.path.exists(ref):
        return Image.open(ref).convert("RGB")
    if ref.startswith("data:image"):
        # Discard the "data:image/...;base64" header, keep only the payload.
        # Using partition avoids an unpacking crash on a malformed URL
        # without a comma, which the old split(",", 1) would raise.
        _, _, b64 = ref.partition(",")
        if not b64:
            raise ValueError(f"Unsupported image reference: {ref[:80]}...")
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    # Last resort: treat the whole string as raw base64.
    try:
        return Image.open(io.BytesIO(base64.b64decode(ref))).convert("RGB")
    except Exception as exc:
        # Chain the original decode/parse error for easier debugging.
        raise ValueError(f"Unsupported image reference: {ref[:80]}...") from exc
|
||
|
||
|
||
def pick_dtype(req_dtype: str, device: torch.device) -> torch.dtype:
    """Resolve a requested dtype string to a torch dtype.

    *req_dtype* is one of "auto" | "float16" | "bfloat16" | "float32";
    "auto" picks a sensible default per device family.
    """
    explicit = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if req_dtype in explicit:
        return explicit[req_dtype]
    # "auto" path below.
    if device.type in ("cuda", "xpu"):
        # bfloat16 works broadly on modern GPUs; fall back to float16 when
        # the capability probe is unavailable or reports no bf16 support.
        try:
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        except Exception:
            return torch.float16
    return torch.float16 if device.type == "mps" else torch.float32
|
||
|
||
|
||
def autocast_ctx(device: torch.device, dtype: torch.dtype):
    """Return an autocast context for *device*, or a no-op context otherwise.

    All four supported backends use the same torch.autocast entry point,
    differing only in the device_type string; anything else gets a
    nullcontext so callers can use it unconditionally in a `with`.
    """
    if device.type in ("cpu", "cuda", "xpu", "mps"):
        return torch.autocast(device_type=device.type, dtype=dtype)
    from contextlib import nullcontext
    return nullcontext()
|
||
|
||
|
||
# ----------------------------
|
||
# Requests/Responses
|
||
# ----------------------------
|
||
class GenParams(BaseModel):
    """Generation hyper-parameters; defaults give greedy decoding."""
    max_new_tokens: int = 128
    temperature: float = 0.0
    top_p: float = 1.0
    do_sample: bool = False  # False => greedy; temperature/top_p only matter when True
|
||
|
||
|
||
class InferRequest(BaseModel):
    """Body of POST /infer."""
    model_path: str  # HF hub id or local checkpoint path
    prompt: str
    images: List[str]  # URLs, local paths, data URLs, or raw base64 strings
    generation: GenParams = GenParams()
    dtype: str = "auto"  # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1  # dummy generations before the timed run
    measure_token_times: bool = False  # NOTE(review): currently unused by /infer — confirm
|
||
|
||
|
||
class InferResponse(BaseModel):
    """Body returned by POST /infer."""
    output_text: str
    timings_ms: Dict[str, float]  # preprocess / generate / postprocess / e2e, in ms
    device: Dict[str, Any]  # type, name, total_memory_gb
    model_info: Dict[str, Any]  # name, precision, framework
|
||
|
||
|
||
class LoadModelRequest(BaseModel):
    """Body of POST /load_model: eagerly load a checkpoint into the cache."""
    model_path: str
    dtype: str = "auto"  # same options as InferRequest.dtype
|
||
|
||
|
||
class UnloadModelRequest(BaseModel):
    """Body of POST /unload_model: evict a checkpoint from the cache."""
    model_path: str
|
||
|
||
|
||
# ----------------------------
|
||
# Model loading
|
||
# ----------------------------
|
||
def resolve_model(model_path: str, dtype_str: str) -> LoadedModel:
    """Load the model at *model_path*, or return the cached instance.

    Dispatches on the checkpoint's ``config.model_type`` to a family-specific
    loader (Qwen2-VL, InternLM-XComposer2, Gemma-3), and otherwise falls back
    to trying the generic Auto* classes in turn. Successful loads are cached
    in the module-level ``_loaded`` dict.

    Raises:
        RuntimeError: when no Auto* class can load the checkpoint.
    """
    if model_path in _loaded:
        return _loaded[model_path]

    dev = best_device()
    dt = pick_dtype(dtype_str, dev)

    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    print(f"model type detected: {model_type}, device: {dev}, dt: {dt}")

    if model_type in ("qwen2_vl", "qwen2-vl"):
        print(f"Loading Qwen2-VL using Qwen2VLForConditionalGeneration")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,
            trust_remote_code=True,
        )
        print("Loaded model class:", type(model))
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm
    elif model_type == "internlmxcomposer2":
        # BUG FIX: the original condition was `model_type in ("internlmxcomposer2")`.
        # `("internlmxcomposer2")` is a plain string, not a 1-tuple, so that was a
        # *substring* test — e.g. model_type "composer" would have matched.
        dt = torch.float16  # this family is served in fp16 regardless of request
        print(f"dt change to {dt}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dt, trust_remote_code=True, device_map='auto')
        model = model.eval()
        lm = LoadedModel(model_type, model_path, model, None, tokenizer, dev, dt)
        _loaded[model_path] = lm
        return lm
    elif model_type in ("gemma3", "gemma-3", "gemma_3"):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # we move to device below
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        model.to(dev).eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm

    # Generic fallback for any other model family.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # Tokenizer is often part of Processor; still try to load explicitly for safety.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    except Exception:
        tokenizer = None

    model = None
    errors = []
    for candidate in (AutoModel, AutoModelForVision2Seq, AutoModelForCausalLM):
        try:
            model = candidate.from_pretrained(
                model_path,
                torch_dtype=dt if dt != torch.float32 else None,
                device_map=None,  # we move to device manually
                trust_remote_code=True
            )
            break
        except Exception as e:
            errors.append(str(e))
    if model is None:
        raise RuntimeError(f"Unable to load model {model_path}. Errors: {errors}")

    model.to(dev)
    model.eval()

    lm = LoadedModel(model_type, model_path, model, processor, tokenizer, dev, dt)
    _loaded[model_path] = lm
    return lm
|
||
|
||
|
||
def unload_model(model_path: str):
    """Drop a cached model and free its accelerator memory (best effort).

    Silently does nothing when *model_path* is not in the cache; cleanup
    errors are swallowed so an unload request never fails.
    """
    lm = _loaded.pop(model_path, None)
    if lm is None:
        return
    try:
        del lm.model
        del lm.processor
        if lm.tokenizer:
            del lm.tokenizer
        gc.collect()
        # Return freed blocks to the CUDA allocator so other processes see them.
        if lm.device.type == "cuda":
            torch.cuda.empty_cache()
    except Exception:
        pass
|
||
|
||
|
||
# ----------------------------
|
||
# Core inference
|
||
# ----------------------------
|
||
def prepare_inputs(lm: LoadedModel, prompt: str, images: List[Image.Image]) -> Dict[str, Any]:
    """Encode *prompt* + *images* with the model's processor and move tensors to its device."""
    processor = lm.processor

    # Prefer the chat-template path when available (covers mllama, qwen2vl,
    # MiniCPM-V2, ...).
    if hasattr(processor, "apply_chat_template"):
        # NOTE(review): only a single image placeholder is emitted; presumably
        # multi-image prompts need one {"type": "image"} entry per image — confirm.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }]
        rendered = processor.apply_chat_template(messages, add_generation_prompt=True)

        encoded = processor(
            text=[rendered],  # some processors require a list here
            images=images,    # list of PIL images
            padding=True,
            return_tensors="pt",
        )
    else:
        # Generic fallback for processors without chat templates.
        encoded = processor(text=prompt, images=images, return_tensors="pt")

    # Ship every tensor to the target device; pass other values through unchanged.
    moved = {}
    for key, value in encoded.items():
        moved[key] = value.to(lm.device) if torch.is_tensor(value) else value
    return moved
|
||
|
||
|
||
def generate_text(lm: LoadedModel, inputs: Dict[str, Any], gen: GenParams) -> str:
    """Run generation on pre-encoded *inputs* and decode the first sequence.

    Decoding tries, in order: an explicit tokenizer on the LoadedModel, the
    tokenizer embedded in its processor, then a model-provided decode();
    returns the sentinel "<decode_failed>" when all three fail.
    """
    gen_kwargs = {
        "max_new_tokens": gen.max_new_tokens,
        "temperature": gen.temperature,
        "top_p": gen.top_p,
        "do_sample": gen.do_sample,
    }

    with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
        out_ids = lm.model.generate(**inputs, **gen_kwargs)

    if lm.tokenizer:
        return lm.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    embedded = getattr(lm.processor, "tokenizer", None)
    if embedded is not None:
        return embedded.decode(out_ids[0], skip_special_tokens=True)
    # Last resort: some remote-code models ship their own decode().
    try:
        return lm.model.decode(out_ids[0])
    except Exception:
        return "<decode_failed>"
|
||
|
||
|
||
def time_block(fn, sync, *args, **kwargs) -> Tuple[Any, float]:
    """Run ``fn(*args, **kwargs)``, then ``sync()``; return (result, elapsed_ms).

    ``sync`` should flush asynchronous device work so the measured wall time
    covers kernel execution, not just kernel launch.
    """
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    sync()
    elapsed_ms = 1000.0 * (time.perf_counter() - t0)
    return result, elapsed_ms
|
||
|
||
|
||
# ----------------------------
|
||
# Routes
|
||
# ----------------------------
|
||
@app.get("/health")
def health():
    """Liveness probe: reports the selected device and backend availability."""
    dev = best_device()
    return dict(
        status="ok",
        device=str(dev),
        device_name=device_name(dev),
        torch=torch.__version__,
        cuda_available=torch.cuda.is_available(),
        mps_available=hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        xpu_available=hasattr(torch, "xpu") and torch.xpu.is_available(),
    )
|
||
|
||
|
||
@app.get("/info")
def info():
    """Static environment info: device details plus torch/transformers versions."""
    import transformers
    dev = best_device()
    return dict(
        device=dict(
            type=dev.type,
            name=device_name(dev),
            total_memory_gb=device_total_mem_gb(dev),
        ),
        torch=torch.__version__,
        transformers=transformers.__version__,
    )
|
||
|
||
|
||
@app.post("/load_model")
def load_model(req: LoadModelRequest):
    """Eagerly load a model into the cache so later /infer calls skip load latency."""
    lm = resolve_model(req.model_path, req.dtype)
    print(f"model with path {req.model_path} loaded!")
    return dict(
        loaded=lm.model_path,
        device=str(lm.device),
        dtype=str(lm.dtype),
    )
|
||
|
||
|
||
@app.post("/unload_model")
def unload(req: UnloadModelRequest):
    """Evict a model from the in-process cache and free its memory.

    BUG FIX: the response previously read ``req.model``, an attribute that
    does not exist on UnloadModelRequest (the field is ``model_path``), so
    every successful unload ended in an AttributeError / HTTP 500.
    """
    unload_model(req.model_path)
    return {"unloaded": req.model_path}
|
||
|
||
|
||
def handle_normal_case(lm: LoadedModel, warmup_runs: int, images: List[str], prompt: str, generation: GenParams):
    """Generic processor-based inference path with per-stage timings.

    Returns (text, preprocess_ms, generate_ms, postprocess_ms).

    Raises:
        HTTPException: 400 when any image reference fails to load.
    """
    # Warm up with a tiny dummy prompt/image; best-effort, so any failure
    # simply stops the warmup loop.
    for _ in range(max(0, warmup_runs)):
        try:
            dummy = Image.new("RGB", (64, 64), color=(128, 128, 128))
            _ = generate_text(lm, prepare_inputs(lm, "Hello", [dummy]), GenParams(max_new_tokens=8))
            synchronize(lm.device)
        except Exception:
            break

    # Resolve image references up front so bad input is a clean 400.
    try:
        pil_images = [load_image(s) for s in images]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")

    # Timed stages; each sync flushes async device work before the clock stops.
    synchronize(lm.device)
    inputs, t_pre = time_block(lambda: prepare_inputs(lm, prompt, pil_images), lambda: synchronize(lm.device))
    text, t_gen = time_block(lambda: generate_text(lm, inputs, generation), lambda: synchronize(lm.device))
    # Placeholder stage for future output cleaning/parsing/detokenization.
    _, t_post = time_block(lambda: None, lambda: synchronize(lm.device))
    return text, t_pre, t_gen, t_post
|
||
|
||
|
||
def handle_minicpmv(lm: LoadedModel, image: Image.Image, prompt: str, gen: GenParams):
    """MiniCPM-V path: delegate to the model's built-in ``chat()`` API.

    Returns (text, preprocess_ms, generate_ms, postprocess_ms); pre/post are
    0.0 because chat() performs its own preprocessing internally.
    """
    def run_chat() -> str:
        # msgs in the format expected by model.chat.
        messages = [{"role": "user", "content": prompt}]
        return lm.model.chat(
            image=image,
            msgs=messages,
            tokenizer=lm.tokenizer,
            sampling=gen.do_sample,
            temperature=gen.temperature,
            stream=False  # Set True if you want streaming later
        )

    synchronize(lm.device)
    text, t_gen = time_block(run_chat, lambda: synchronize(lm.device))
    return text, 0.0, t_gen, 0.0
|
||
|
||
|
||
def handle_internlm_xcomposer(lm: LoadedModel,
                              images_pil: List[Image.Image],
                              prompt: str,
                              gen: GenParams):
    """InternLM-XComposer2 path: preprocess with the model's transform, then chat().

    Returns (text, preprocess_ms, generate_ms, postprocess_ms); pre/post are
    0.0 because preprocessing happens inside the timed chat call.
    """
    def run_chat():
        # Preprocess every image with the model-supplied CLIP transform and
        # stack into one batch on the target device/dtype.
        tensors = [lm.model.vis_processor(img.convert("RGB")) for img in images_pil]
        batch = torch.stack(tensors).to(lm.device, dtype=lm.dtype)

        # One <ImageHere> placeholder per picture, then the user prompt.
        query = ("<ImageHere> " * len(images_pil)).strip() + " " + prompt

        # Chat-style generation; model.chat returns (response, history).
        with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
            response, _ = lm.model.chat(
                lm.tokenizer,
                query=query,
                image=batch,
                history=[],
                do_sample=gen.do_sample,
                temperature=gen.temperature,
                max_new_tokens=gen.max_new_tokens,
            )
        return response

    synchronize(lm.device)
    text, t_gen = time_block(run_chat, lambda: synchronize(lm.device))
    return text, 0.0, t_gen, 0.0
|
||
|
||
|
||
def handle_qwen2vl(lm: LoadedModel, image_strings: List[str], prompt: str, gen: GenParams):
    """Qwen2-VL path: chat-template encode, generate, decode only the new tokens.

    Returns (text, preprocess_ms, generate_ms, postprocess_ms); the timing
    slots are currently 0.0 placeholders.

    BUG FIX: the original returned ``processor.batch_decode(...)`` directly,
    which is a *list* of strings; InferResponse declares ``output_text: str``,
    so every Qwen2-VL response failed response-model validation. We now return
    the first decoded string. The comprehension variables are also renamed —
    they previously shadowed ``output_ids`` itself.
    """
    images = [load_image(s) for s in image_strings]
    # NOTE(review): only the first image is fed to the model even when more
    # are supplied — confirm whether multi-image support is intended.
    image = images[0]

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text_prompt = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = lm.processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(lm.device)

    output_ids = lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens)

    # Strip the prompt tokens so only newly generated ones get decoded.
    generated_ids = [
        out_seq[len(in_seq):]
        for in_seq, out_seq in zip(inputs.input_ids, output_ids)
    ]
    decoded = lm.processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return decoded[0], 0.0, 0.0, 0.0  # placeholder timing values
|
||
|
||
|
||
def handle_gemma3(lm: LoadedModel, image_refs: List[str], prompt: str, gen: GenParams):
    """Gemma-3 path: chat-template encode, timed generate, decode.

    Returns (text, preprocess_ms, generate_ms, postprocess_ms); only the
    generate stage is timed. Only the first image reference is used.
    """
    image = load_image(image_refs[0])

    conversation = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ],
    }]
    rendered = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = lm.processor(
        image,
        rendered,
        add_special_tokens=False,
        return_tensors="pt"
    ).to(lm.device)

    synchronize(lm.device)
    out_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device)
    )
    return lm.processor.decode(out_ids[0], skip_special_tokens=True), 0.0, t_gen, 0.0
|
||
|
||
|
||
@app.post("/infer", response_model=InferResponse)
def infer(req: InferRequest):
    """Run one inference request, dispatching on the loaded model's type.

    Loads (or reuses) the model, routes to the family-specific handler,
    and wraps the generated text with timing and environment metadata.
    """
    print("infer got")
    lm = resolve_model(req.model_path, req.dtype)
    print(f"{lm.model_type=}")

    model_type = lm.model_type
    if model_type == 'minicpmv':
        text, t_pre, t_gen, t_post = handle_minicpmv(lm, load_image(req.images[0]), req.prompt, req.generation)
    elif model_type in ("qwen2vl", "qwen2-vl", "qwen2_vl"):
        text, t_pre, t_gen, t_post = handle_qwen2vl(lm, req.images, req.prompt, req.generation)
    elif model_type == "internlmxcomposer2":
        # This handler takes PIL images, so resolve references here.
        try:
            pil_images = [load_image(s) for s in req.images]
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
        text, t_pre, t_gen, t_post = handle_internlm_xcomposer(lm, pil_images, req.prompt, req.generation)
    else:
        text, t_pre, t_gen, t_post = handle_normal_case(lm, req.warmup_runs, req.images, req.prompt, req.generation)

    timings = {
        "preprocess": t_pre,
        "generate": t_gen,
        "postprocess": t_post,
        "e2e": t_pre + t_gen + t_post
    }

    return InferResponse(
        output_text=text,
        timings_ms=timings,
        device={
            "type": lm.device.type,
            "name": device_name(lm.device),
            "total_memory_gb": device_total_mem_gb(lm.device)
        },
        model_info={
            "name": lm.model_path,
            "precision": str(lm.dtype).replace("torch.", ""),
            "framework": "transformers"
        }
    )
|
||
|
||
# Entry
|
||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||
|