Files
2025-08-25 14:54:08 +08:00

605 lines
19 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import gc
import io
import os
import time
from typing import List, Optional, Dict, Any, Tuple
import torch
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
AutoProcessor,
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq, AutoModel
)
try:
from transformers import (Qwen2VLForConditionalGeneration, Gemma3ForConditionalGeneration)
except ImportError:
pass
# FastAPI application serving the unified VLM inference endpoints.
app = FastAPI(title="Unified VLM API (Transformers)")
# ----------------------------
# Device selection & sync
# ----------------------------
def best_device() -> torch.device:
    """Pick the most capable accelerator available, falling back to CPU."""
    # torch.cuda covers both NVIDIA CUDA and AMD ROCm builds.
    candidates = (
        (torch.cuda.is_available(), "cuda"),
        # Intel oneAPI/XPU (ipex) exposes torch.xpu when installed.
        (hasattr(torch, "xpu") and torch.xpu.is_available(), "xpu"),
        # Apple Metal Performance Shaders backend.
        (hasattr(torch.backends, "mps") and torch.backends.mps.is_available(), "mps"),
    )
    for available, kind in candidates:
        if available:
            return torch.device(kind)
    return torch.device("cpu")
def device_name(device: torch.device) -> str:
    """Return a human-readable label for *device*."""
    kind = device.type
    if kind == "cuda":
        # Querying the adapter name can fail (driver issues); degrade gracefully.
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "CUDA device"
    labels = {"xpu": "Intel XPU", "mps": "Apple MPS"}
    return labels.get(kind, "CPU")
def device_total_mem_gb(device: torch.device) -> Optional[float]:
    """Total device memory in GiB, or None when it cannot be determined."""
    if device.type != "cuda":
        # Memory reporting for xpu/mps/cpu varies too much to be useful here.
        return None
    try:
        props = torch.cuda.get_device_properties(0)
    except Exception:
        return None
    return props.total_memory / 1024 ** 3
def synchronize(device: torch.device):
    """Block until pending work on *device* finishes (best effort).

    Needed so wall-clock timings reflect asynchronous GPU kernel execution.
    """
    try:
        kind = device.type
        if kind == "cuda":
            torch.cuda.synchronize()
        elif kind == "xpu" and hasattr(torch, "xpu"):
            torch.xpu.synchronize()
        elif kind == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()
    except Exception:
        # A failed sync must never take down an inference request.
        pass
# ----------------------------
# Model registry
# ----------------------------
class LoadedModel:
    """In-memory record of one loaded model plus its helpers and placement."""

    def __init__(self, model_type: str, model_path: str, model, processor, tokenizer,
                 device: torch.device, dtype: torch.dtype):
        # HF config.model_type string; used by the API layer to pick a handler.
        self.model_type = model_type
        # Path or hub id the model was loaded from; also the registry key.
        self.model_path = model_path
        self.model = model
        # Some architectures use a processor, others a bare tokenizer; either may be None.
        self.processor = processor
        self.tokenizer = tokenizer
        # Device the model weights were moved to and the precision they were loaded in.
        self.device = device
        self.dtype = dtype
# Registry of models currently resident in memory, keyed by model_path.
_loaded: Dict[str, LoadedModel] = {}
# ----------------------------
# IO helpers
# ----------------------------
def load_image(ref: str) -> Image.Image:
    """Load an RGB PIL image from a URL, local path, data URL, or raw base64 string.

    Args:
        ref: http(s) URL, filesystem path, "data:image/...;base64,..." URL,
             or a bare base64-encoded image payload.

    Raises:
        ValueError: if *ref* matches none of the supported forms.
    """
    if ref.startswith(("http://", "https://")):
        # Imported lazily so offline deployments never need requests installed.
        import requests
        resp = requests.get(ref, timeout=30)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")
    # Data URLs checked before the filesystem: avoids feeding potentially huge
    # base64 strings to os.path.exists (and the header can never be a real path).
    if ref.startswith("data:image"):
        _, b64 = ref.split(",", 1)  # header before the comma is not needed
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    if os.path.exists(ref):
        return Image.open(ref).convert("RGB")
    # Last resort: treat the whole string as raw base64-encoded image bytes.
    try:
        return Image.open(io.BytesIO(base64.b64decode(ref))).convert("RGB")
    except Exception as exc:
        # Chain the decode error so callers can see why the fallback failed.
        raise ValueError(f"Unsupported image reference: {ref[:80]}...") from exc
def pick_dtype(req_dtype: str, device: torch.device) -> torch.dtype:
    """Resolve a requested dtype string ("auto" or explicit) for *device*."""
    explicit = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if req_dtype in explicit:
        return explicit[req_dtype]
    # "auto": choose the fastest reduced precision the device is known to handle.
    if device.type in ("cuda", "xpu"):
        # bfloat16 works broadly on modern GPUs; older CUDA parts get float16.
        try:
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
        except Exception:
            pass
        return torch.float16
    if device.type == "mps":
        return torch.float16
    return torch.float32
def autocast_ctx(device: torch.device, dtype: torch.dtype):
    """Return an autocast context for *device*/*dtype*, or a no-op context.

    BUG FIX: the original unconditionally built torch.autocast for known device
    types, but autocast only accepts reduced-precision dtypes -- entering a
    float32 CPU autocast raises in current PyTorch, which crashed the default
    CPU inference path. float32 now maps to a no-op context instead.
    """
    from contextlib import nullcontext

    if dtype == torch.float32:
        # Full precision needs no autocast at all.
        return nullcontext()
    if device.type in ("cpu", "cuda", "xpu", "mps"):
        try:
            return torch.autocast(device_type=device.type, dtype=dtype)
        except Exception:
            # Older torch builds may not support autocast for this backend.
            return nullcontext()
    # Unknown backend: fall back to a no-op.
    return nullcontext()
# ----------------------------
# Requests/Responses
# ----------------------------
class GenParams(BaseModel):
    """Decoding parameters forwarded to ``model.generate``."""

    max_new_tokens: int = 128
    temperature: float = 0.0  # only meaningful when do_sample is True
    top_p: float = 1.0
    do_sample: bool = False  # greedy decoding by default
class InferRequest(BaseModel):
    """Request body for POST /infer."""

    model_path: str  # local path or HF hub id
    prompt: str
    images: List[str]  # URLs, file paths, data URLs, or raw base64 strings
    generation: GenParams = GenParams()
    dtype: str = "auto"  # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1  # warmup generations before the timed run
    measure_token_times: bool = False  # NOTE(review): not read by any handler in this file
class InferResponse(BaseModel):
    """Response body for POST /infer."""

    output_text: str
    timings_ms: Dict[str, float]  # keys: preprocess / generate / postprocess / e2e
    device: Dict[str, Any]  # device type, name, total memory
    model_info: Dict[str, Any]  # model path, precision, framework
class LoadModelRequest(BaseModel):
    """Request body for POST /load_model."""

    model_path: str
    dtype: str = "auto"
class UnloadModelRequest(BaseModel):
    """Request body for POST /unload_model."""

    model_path: str
# ----------------------------
# Model loading
# ----------------------------
def resolve_model(model_path: str, dtype_str: str) -> LoadedModel:
    """Load the model at *model_path* (or return the cached one) and register it.

    Dispatches on the HF config's ``model_type`` for architectures that need a
    dedicated loading path; everything else goes through the Auto* fallback chain.

    Raises:
        RuntimeError: if no Auto* class can load the checkpoint.
    """
    if model_path in _loaded:
        return _loaded[model_path]
    dev = best_device()
    dt = pick_dtype(dtype_str, dev)
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    print(f"model type detected: {model_type}, device: {dev}, dt: {dt}")
    if model_type in ("qwen2_vl", "qwen2-vl"):
        print("Loading Qwen2-VL using Qwen2VLForConditionalGeneration")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # moved to the chosen device below
            trust_remote_code=True,
        )
        print("Loaded model class:", type(model))
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm
    # BUG FIX: the original wrote `model_type in ("internlmxcomposer2")` -- a
    # *substring* test against a plain string (e.g. "internlm" also matched).
    # A one-element tuple restores the intended exact-match semantics.
    elif model_type in ("internlmxcomposer2",):
        # Force float16 for this architecture (original behavior).
        dt = torch.float16
        print(f"dt change to {dt}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=dt, trust_remote_code=True, device_map='auto')
        model = model.eval()
        lm = LoadedModel(model_type, model_path, model, None, tokenizer, dev, dt)
        _loaded[model_path] = lm
        return lm
    elif model_type in ("gemma3", "gemma-3", "gemma_3"):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # we move to device below
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev).eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm
    # Generic fallback: processor + the first Auto* class that accepts the checkpoint.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # The tokenizer is usually bundled in the processor; load explicitly when possible.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    except Exception:
        tokenizer = None
    model = None
    errors = []
    for candidate in (AutoModel, AutoModelForVision2Seq, AutoModelForCausalLM):
        try:
            model = candidate.from_pretrained(
                model_path,
                torch_dtype=dt if dt != torch.float32 else None,
                device_map=None,  # we move to device manually
                trust_remote_code=True,
            )
            break
        except Exception as e:
            errors.append(str(e))
    if model is None:
        raise RuntimeError(f"Unable to load model {model_path}. Errors: {errors}")
    model.to(dev)
    model.eval()
    lm = LoadedModel(model_type, model_path, model, processor, tokenizer, dev, dt)
    _loaded[model_path] = lm
    return lm
def unload_model(model_path: str):
    """Drop *model_path* from the registry and release its memory (best effort)."""
    lm = _loaded.pop(model_path, None)
    if lm is None:
        return
    try:
        del lm.model
        del lm.processor
        if lm.tokenizer:
            del lm.tokenizer
        gc.collect()
        # Return freed blocks to the driver so other processes can use them.
        if lm.device.type == "cuda":
            torch.cuda.empty_cache()
    except Exception:
        # Cleanup must never raise back into the API handler.
        pass
# ----------------------------
# Core inference
# ----------------------------
def prepare_inputs(lm: LoadedModel, prompt: str, images: List[Image.Image]) -> Dict[str, Any]:
    """Encode *prompt* + *images* with the model's processor and move tensors to device.

    Uses the processor's chat template when available (covers mllama, qwen2-vl,
    MiniCPM-V2, ...), falling back to a plain processor call otherwise.

    BUG FIX: the original always emitted exactly one image placeholder in the
    chat template (its own comment said "repeat if >1" but never did), so
    multi-image requests mismatched the number of image features. One
    placeholder is now emitted per supplied image.
    """
    proc = lm.processor
    if hasattr(proc, "apply_chat_template"):
        content = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        conversation = [{"role": "user", "content": content}]
        text_prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)
        encoded = proc(
            text=[text_prompt],  # some processors require a list here
            images=images,  # list-of-PIL
            padding=True,
            return_tensors="pt",
        )
    else:
        # Generic fallback for processors without a chat template.
        encoded = proc(text=prompt, images=images, return_tensors="pt")
    # Move tensors to the target device; pass everything else through untouched.
    return {k: v.to(lm.device) if torch.is_tensor(v) else v for k, v in encoded.items()}
def generate_text(lm: LoadedModel, inputs: Dict[str, Any], gen: GenParams) -> str:
    """Run generation on prepared *inputs* and decode the first output sequence."""
    with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
        out_ids = lm.model.generate(
            **inputs,
            max_new_tokens=gen.max_new_tokens,
            temperature=gen.temperature,
            top_p=gen.top_p,
            do_sample=gen.do_sample,
        )
    sequence = out_ids[0]
    # Prefer an explicit tokenizer, then the one embedded in the processor.
    if lm.tokenizer:
        return lm.tokenizer.decode(sequence, skip_special_tokens=True)
    embedded = getattr(lm.processor, "tokenizer", None)
    if embedded is not None:
        return embedded.decode(sequence, skip_special_tokens=True)
    # Last resort: some remote-code models expose their own decode().
    try:
        return lm.model.decode(sequence)
    except Exception:
        return "<decode_failed>"
def time_block(fn, sync, *args, **kwargs) -> Tuple[Any, float]:
    """Call ``fn(*args, **kwargs)``, then ``sync()``, returning (result, elapsed_ms).

    *sync* should block until outstanding device work finishes so that
    asynchronous GPU time is included in the measurement.
    """
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    sync()
    return result, (time.perf_counter() - t0) * 1000.0
# ----------------------------
# Routes
# ----------------------------
@app.get("/health")
def health():
    """Liveness probe reporting the selected device and backend availability."""
    dev = best_device()
    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    xpu_ok = hasattr(torch, "xpu") and torch.xpu.is_available()
    return {
        "status": "ok",
        "device": str(dev),
        "device_name": device_name(dev),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "mps_available": mps_ok,
        "xpu_available": xpu_ok,
    }
@app.get("/info")
def info():
    """Environment info: device characteristics and library versions."""
    dev = best_device()
    # Imported here, mirroring the original's lazy __import__ call.
    import transformers
    return {
        "device": {
            "type": dev.type,
            "name": device_name(dev),
            "total_memory_gb": device_total_mem_gb(dev),
        },
        "torch": torch.__version__,
        "transformers": transformers.__version__,
    }
@app.post("/load_model")
def load_model(req: LoadModelRequest):
    """Eagerly load a model so the first /infer call does not pay load time."""
    lm = resolve_model(req.model_path, req.dtype)
    print(f"model with path {req.model_path} loaded!")
    return {
        "loaded": lm.model_path,
        "device": str(lm.device),
        "dtype": str(lm.dtype),
    }
@app.post("/unload_model")
def unload(req: UnloadModelRequest):
    """Unload a previously loaded model and free its memory."""
    unload_model(req.model_path)
    # BUG FIX: the original returned `req.model`, but UnloadModelRequest only
    # defines `model_path`, so every successful unload raised AttributeError.
    return {"unloaded": req.model_path}
def handle_normal_case(lm: LoadedModel, warmup_runs: int, images: List[str], prompt: str, generation: GenParams):
    """Generic preprocess -> generate -> postprocess path with per-stage timing.

    Returns (text, t_pre_ms, t_gen_ms, t_post_ms).
    """
    # Warmup passes on a tiny dummy image; stop silently if the model rejects it.
    for _ in range(max(0, warmup_runs)):
        try:
            dummy = Image.new("RGB", (64, 64), color=(128, 128, 128))
            _ = generate_text(lm, prepare_inputs(lm, "Hello", [dummy]), GenParams(max_new_tokens=8))
            synchronize(lm.device)
        except Exception:
            break
    # Resolve all image references up front so failures map to a 400.
    try:
        pil_images = [load_image(ref) for ref in images]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
    # Drain pending device work so warmup time is not attributed to the stages.
    synchronize(lm.device)
    sync = lambda: synchronize(lm.device)
    inputs, t_pre = time_block(lambda: prepare_inputs(lm, prompt, pil_images), sync)
    text, t_gen = time_block(lambda: generate_text(lm, inputs, generation), sync)
    # Placeholder stage for future output cleaning/parsing/transforming.
    _, t_post = time_block(lambda: None, sync)
    return text, t_pre, t_gen, t_post
def handle_minicpmv(lm: LoadedModel, image: Image.Image, prompt: str, gen: GenParams):
    """MiniCPM-V path: delegate to the model's built-in chat() API.

    Returns (text, t_pre_ms, t_gen_ms, t_post_ms); pre/post are zero because
    chat() performs its own pre- and post-processing internally.
    """
    def run_chat() -> str:
        messages = [{"role": "user", "content": prompt}]
        return lm.model.chat(
            image=image,
            msgs=messages,
            tokenizer=lm.tokenizer,
            sampling=gen.do_sample,
            temperature=gen.temperature,
            stream=False,  # flip to True if streaming is ever needed
        )

    synchronize(lm.device)
    text, t_gen = time_block(run_chat, lambda: synchronize(lm.device))
    return text, 0.0, t_gen, 0.0
def handle_internlm_xcomposer(lm: LoadedModel,
                              images_pil: List[Image.Image],
                              prompt: str,
                              gen: GenParams):
    """InternLM-XComposer2 path: model-supplied CLIP preprocessing plus chat().

    Returns (text, t_pre_ms, t_gen_ms, t_post_ms); pre/post are zero because
    chat() performs its own pre- and post-processing internally.
    """
    def run_chat():
        # Step 1: preprocess every image with the model-supplied CLIP transform.
        tensors = [lm.model.vis_processor(img.convert("RGB")) for img in images_pil]
        batch = torch.stack(tensors).to(lm.device, dtype=lm.dtype)
        # Step 2: build the query string, one <ImageHere> token per picture.
        query = ("<ImageHere> " * len(images_pil)).strip() + " " + prompt
        # Step 3: run chat-style generation.
        with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
            response, _ = lm.model.chat(
                lm.tokenizer,
                query=query,
                image=batch,
                history=[],
                do_sample=gen.do_sample,
                temperature=gen.temperature,
                max_new_tokens=gen.max_new_tokens,
            )
        return response

    synchronize(lm.device)
    text, t_gen = time_block(run_chat, lambda: synchronize(lm.device))
    return text, 0.0, t_gen, 0.0
def handle_qwen2vl(lm: LoadedModel, image_strings: List[str], prompt: str, gen: GenParams):
    """Qwen2-VL path: chat-template prompt, single image, prompt-trimmed decode.

    Returns (text, t_pre_ms, t_gen_ms, t_post_ms); timings are not measured on
    this path and are reported as zero.

    NOTE(review): every image reference is downloaded/decoded but only the
    first is fed to the model -- confirm whether multi-image support is
    intended before pruning the extra loads.
    """
    images = [load_image(s) for s in image_strings]
    image = images[0]
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text_prompt = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = lm.processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(lm.device)
    output_ids = lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens)
    # Strip the prompt tokens so only newly generated tokens are decoded.
    # (Also renames the comprehension variables: the original shadowed
    # `output_ids` inside its own comprehension.)
    trimmed = [
        out_seq[len(in_seq):]
        for in_seq, out_seq in zip(inputs.input_ids, output_ids)
    ]
    decoded = lm.processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # BUG FIX: batch_decode returns a list; the original returned the whole
    # list, which fails InferResponse's `output_text: str` validation in /infer.
    return decoded[0], 0.0, 0.0, 0.0
def handle_gemma3(lm: LoadedModel, image_refs: List[str], prompt: str, gen: GenParams):
    """Gemma-3 path: chat-template prompt with a single image, timed generate.

    Returns (text, t_pre_ms, t_gen_ms, t_post_ms); only generation is timed.
    """
    image = load_image(image_refs[0])
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    chat_text = lm.processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = lm.processor(
        image,
        chat_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(lm.device)
    # Drain pending device work so only generation lands in the measurement.
    synchronize(lm.device)
    out_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device),
    )
    text = lm.processor.decode(out_ids[0], skip_special_tokens=True)
    return text, 0.0, t_gen, 0.0
@app.post("/infer", response_model=InferResponse)
def infer(req: InferRequest):
    """Run one multimodal generation, dispatching on the model architecture.

    Raises:
        HTTPException(400): when an image reference cannot be loaded.
    """
    print("infer got")
    # Load / reuse model from the in-memory registry.
    lm = resolve_model(req.model_path, req.dtype)
    print(f"{lm.model_type=}")
    if lm.model_type == 'minicpmv':
        text, t_pre, t_gen, t_post = handle_minicpmv(
            lm, load_image(req.images[0]), req.prompt, req.generation)
    elif lm.model_type in ("qwen2vl", "qwen2-vl", "qwen2_vl"):
        text, t_pre, t_gen, t_post = handle_qwen2vl(lm, req.images, req.prompt, req.generation)
    elif lm.model_type == "internlmxcomposer2":
        try:
            pil_images = [load_image(s) for s in req.images]
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
        text, t_pre, t_gen, t_post = handle_internlm_xcomposer(
            lm, pil_images, req.prompt, req.generation)
    # BUG FIX: handle_gemma3 (and resolve_model's dedicated gemma3 loader)
    # existed but was never dispatched -- gemma3 requests silently fell through
    # to the generic path. Route them to their dedicated handler.
    elif lm.model_type in ("gemma3", "gemma-3", "gemma_3"):
        text, t_pre, t_gen, t_post = handle_gemma3(lm, req.images, req.prompt, req.generation)
    else:
        text, t_pre, t_gen, t_post = handle_normal_case(
            lm, req.warmup_runs, req.images, req.prompt, req.generation)
    timings = {
        "preprocess": t_pre,
        "generate": t_gen,
        "postprocess": t_post,
        "e2e": t_pre + t_gen + t_post,
    }
    return InferResponse(
        output_text=text,
        timings_ms=timings,
        device={
            "type": lm.device.type,
            "name": device_name(lm.device),
            "total_memory_gb": device_total_mem_gb(lm.device),
        },
        model_info={
            "name": lm.model_path,
            "precision": str(lm.dtype).replace("torch.", ""),
            "framework": "transformers",
        },
    )
# Entry
# Run: uvicorn server:app --host 0.0.0.0 --port 8000