feature: add
This commit is contained in:
50
README.md
50
README.md
@@ -1,2 +1,52 @@
|
||||
# enginex-mlu370-vl
|
||||
|
||||
# 寒武纪 mlu370 视觉理解多模态
|
||||
该模型测试框架在寒武纪mlu370 (X8/X4)加速卡上,基于Transformer框架,适配了 gemma-3-4b-it、MiniCPM-Llama3-V-2_5 、MiniCPM_V_2_6 这3个模型。
|
||||
|
||||
* Gemma 3-4B‑IT 是 Google 发布的 Gemma 3 系列中参数量为 4 B 的轻量 multimodal 模型,支持图文输入、128 K 长上下文、多语种(140+ 语言),专为嵌入设备快速部署设计
|
||||
* MiniCPM‑Llama3‑V 2.5 是 openbmb 的 8 B multimodal 模型,基于 SigLip‑400M 与 Llama3-8B-Instruct 构建,在 OCR 能力、多语言支持、部署效率等方面表现优秀,整体性能达到 GPT‑4V 级别
|
||||
* MiniCPM‑V 2.6 是 MiniCPM‑V 系列中最新且最强大的 8 B 参数模型,具备更优的单图、多图与视频理解能力、卓越 OCR 效果、低 hallucination 率,并支持端侧设备(如 iPad)实时视频理解
|
||||
|
||||
|
||||
## 模型测试服务原理
|
||||
尽管对于视觉多模态理解没有一个业界统一的API协议标准,但我们也可以基于目前比较流行的Transformer框架**适配**各类视觉理解多模态模型。
|
||||
为了让我们的测试框架更通用一些,我们基于Transformer框架对于不同类型的模型系列适配了一层,方便对外提供http服务。
|
||||
|
||||
目前,测试框架要求用户在测试时首先将需要测试的模型的地址mount到本地文件系统中,如`/model`,之后通过uvicorn拉起服务。
|
||||
|
||||
测试过程中,外围测试环境,会首先调用“加载模型接口”:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:10086/load_model \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model_path":"/model","dtype":"auto"}'
|
||||
```
|
||||
|
||||
|
||||
## 模型测试服务请求示例
|
||||
准备好用于测试的图片和问题,通过infer接口获取推理结果:
|
||||
|
||||
```bash
|
||||
base64 -w 0 demo.jpeg | \
|
||||
jq -Rs --arg mp "/model" --arg prompt "Describe the picture" \
|
||||
'{model_path: $mp, prompt: $prompt, images: ["data:image/jpeg;base64," + .], generation: {max_new_tokens: 50, temperature: 0.7}}' | \
|
||||
curl -X POST "http://localhost:10086/infer" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d @-
|
||||
```
|
||||
|
||||
以上,图片为`demo.jpeg`,问题为`Describe the picture`,可根据需要相应替换。
|
||||
## 如何使用视觉理解多模态测试框架
|
||||
由于VLM相关的模型一般需要较大的存储空间,为了更好的测试效率,需要提前下载好模型相关文件,k8s集群可以mount的持久化介质(比如cephFS),之后提交测试时指定模型存放的地址。
|
||||
|
||||
`docker-images/server.py`代码实现了一个接收图片和问题并返回回答文本和统计延迟信息的VLM HTTP 服务。测试框架集成了现成的可用的镜像`harbor.4pd.io/mic-llm-x/combricon-mlu370x8_test_wyq:1.0.0`(`server.py`作为入口),可以用于本地端(如有GPU卡)测试。
|
||||
|
||||
作为测试对比,我们也提供a100相对应的镜像 `harbor.4pd.io/hardcore-tech/a100-3.2.1-x86-ubuntu20.04-py3.10-poc-vlm-infer:0.0.1`
|
||||
## 寒武纪mlu370-X8上视觉理解多模态模型运行测试结果
|
||||
在mlu370-X8上对部分视觉理解多模态模型进行适配,测试方式为在 Nvidia A100 和 mlu370-X8 加速卡上对10个图片相关问题回答,获取运行时间
|
||||
|
||||
| 模型名称 | 模型类型 | 适配状态 | mlu370-X8运行时间/s | Nvidia A100运行时间/s |
|
||||
| ---------- | ---------------------- | -------- | ----------------- | --------------------- |
|
||||
| Gemma 3-4B‑IT | Gemma 3 系列 | 成功 | 8.8261 | 5.5634 |
|
||||
| MiniCPM‑Llama3‑V 2.5 | openbmb 8 B multimodal | 成功 | 14.7240 | 8.2024 |
|
||||
| MiniCPM‑V 2.6 | MiniCPM‑V 系列 | 成功 | 9.5498 | 4.3531 |
|
||||
8
docker-images/mlu370-x8.dockerfile
Normal file
8
docker-images/mlu370-x8.dockerfile
Normal file
@@ -0,0 +1,8 @@
|
||||
# generate harbor.4pd.io/hardcore-tech/mr100-3.2.1-x86-ubuntu20.04-py3.10-poc-vlm-infer:0.0.1
# NOTE(review): the tag above says "mr100" but this file targets Cambricon
# mlu370 and the README references a combricon-mlu370x8 image — confirm the
# intended published image name.
FROM harbor.4pd.io/mic-llm-x/combricon-mlu370x4-base:v0.2.0-tgiv1.4.3-btv0.6.0-pt2.1-x86_64-ubuntu22.04-py310
# Route Hugging Face downloads through the mirror (restricted-network friendly)
ENV HF_ENDPOINT=https://hf-mirror.com
RUN pip install transformers==4.50.0 uvicorn\[standard\] fastapi
WORKDIR /app
# server.py implements the unified VLM HTTP API (load_model / infer / ...)
COPY server.py /app/server.py
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
7
docker-images/nvidia-a100.dockerfile
Normal file
7
docker-images/nvidia-a100.dockerfile
Normal file
@@ -0,0 +1,7 @@
|
||||
# NVIDIA A100 comparison image for the same VLM test service (see README).
FROM harbor.4pd.io/hardcore-tech/vllm/vllm-openai:v0.8.5.post1
# Route Hugging Face downloads through the mirror (restricted-network friendly)
ENV HF_ENDPOINT=https://hf-mirror.com
# Pin transformers to match the MLU image so results are comparable
RUN pip install transformers==4.50.0
WORKDIR /app
COPY server.py /app/server.py
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
630
docker-images/server.py
Normal file
630
docker-images/server.py
Normal file
@@ -0,0 +1,630 @@
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
import torch
|
||||
try:
|
||||
import torch_mlu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from PIL import Image
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq, AutoModel
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from transformers import (Qwen2VLForConditionalGeneration, Gemma3ForConditionalGeneration)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
app = FastAPI(title="Unified VLM API (Transformers)")
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Device selection & sync
|
||||
# ----------------------------
|
||||
def best_device() -> torch.device:
    """Pick the most capable available accelerator, falling back to CPU.

    Preference order: CUDA (covers NVIDIA and ROCm builds) > Intel XPU >
    Apple MPS > Cambricon MLU > CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Intel oneAPI/XPU (ipex)
    if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
        return torch.device("xpu")
    # Apple Metal Performance Shaders
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    # Cambricon MLU (requires torch_mlu to have been imported)
    if getattr(torch, "mlu", None) is not None and torch.mlu.is_available():
        return torch.device("mlu")
    return torch.device("cpu")
|
||||
|
||||
|
||||
def device_name(device: torch.device) -> str:
    """Human-readable label for *device* (shown in /health and /infer payloads)."""
    if device.type == "cuda":
        # Ask the driver for the real card name; fall back to a generic label
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "CUDA device"
    labels = {
        "xpu": "Intel XPU",
        "mps": "Apple MPS",
        "mlu": "cambricon MLU",
    }
    return labels.get(device.type, "CPU")
|
||||
|
||||
|
||||
def device_total_mem_gb(device: torch.device) -> Optional[float]:
    """Total device memory in GiB, or None when it cannot be queried.

    Only CUDA exposes a portable total-memory property here; every other
    backend (and any query failure) reports None.
    """
    if device.type != "cuda":
        return None
    try:
        props = torch.cuda.get_device_properties(0)
    except Exception:
        return None
    return props.total_memory / (1024 ** 3)
|
||||
|
||||
|
||||
def synchronize(device: torch.device):
    """Block until pending asynchronous work on *device* has finished.

    Called around timed sections so wall-clock measurements reflect real
    device work rather than kernel-launch time.  Failures are swallowed:
    a sync problem must never break an inference request.
    """
    try:
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "xpu" and hasattr(torch, "xpu"):
            torch.xpu.synchronize()
        elif device.type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()
        elif device.type == "mlu" and hasattr(torch, "mlu"):
            # Bug fix: best_device() can return "mlu" (Cambricon), but this
            # function never synchronized it, so MLU timings from time_block()
            # only measured launch overhead.  torch_mlu provides
            # torch.mlu.synchronize(); guarded by hasattr for stock torch.
            torch.mlu.synchronize()
    except Exception:
        pass
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Model registry
|
||||
# ----------------------------
|
||||
class LoadedModel:
    """In-memory record of one loaded model plus everything needed to run it."""

    def __init__(self, model_type: str, model_path: str, model, processor, tokenizer,
                 device: torch.device, dtype: torch.dtype):
        # HF config model_type (e.g. "gemma3", "minicpmv"); drives dispatch in infer()
        self.model_type = model_type
        # Filesystem path the weights came from; also the registry key in _loaded
        self.model_path = model_path
        self.model = model
        # Either of these may be None depending on the model family's loading path
        self.processor = processor
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
|
||||
|
||||
|
||||
_loaded: Dict[str, LoadedModel] = {}
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# IO helpers
|
||||
# ----------------------------
|
||||
def load_image(ref: str) -> Image.Image:
    """Resolve an image reference to an RGB PIL image.

    Accepted forms, checked in order: http(s) URL, existing local file
    path, ``data:image/...;base64,`` URL, raw base64 payload.  Raises
    ValueError when none of these forms can be decoded.
    """
    if ref.startswith(("http://", "https://")):
        # Lazy import: requests is only needed for URL references
        import requests
        resp = requests.get(ref, timeout=30)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")
    if os.path.exists(ref):
        return Image.open(ref).convert("RGB")
    if ref.startswith("data:image"):
        _header, payload = ref.split(",", 1)
        return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
    # Last resort: treat the whole string as raw base64
    try:
        return Image.open(io.BytesIO(base64.b64decode(ref))).convert("RGB")
    except Exception:
        raise ValueError(f"Unsupported image reference: {ref[:80]}...")
|
||||
|
||||
|
||||
def pick_dtype(req_dtype: str, device: torch.device) -> torch.dtype:
    """Map a requested dtype string to a torch dtype suited to *device*.

    Explicit "float16"/"bfloat16"/"float32" are honored as-is.  "auto"
    chooses bf16 (fp16 fallback) on cuda/xpu, fp16 on mps/mlu, and fp32
    on cpu.
    """
    explicit = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if req_dtype in explicit:
        return explicit[req_dtype]
    # "auto" (or anything unrecognized) — pick per device capability
    if device.type in ("cuda", "xpu"):
        # bfloat16 works broadly on modern GPUs; fall back to fp16 when the
        # capability query fails or reports no bf16 support
        try:
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        except Exception:
            return torch.float16
    if device.type in ("mps", "mlu"):
        return torch.float16
    return torch.float32
|
||||
|
||||
|
||||
def autocast_ctx(device: torch.device, dtype: torch.dtype):
    """Return an autocast context for *device*, or a no-op context for
    backends (e.g. MLU) that torch.autocast is not wired up for here."""
    if device.type in ("cpu", "cuda", "xpu", "mps"):
        return torch.autocast(device_type=device.type, dtype=dtype)
    # Unsupported backend: fall back to a do-nothing context manager
    from contextlib import nullcontext
    return nullcontext()
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Requests/Responses
|
||||
# ----------------------------
|
||||
class GenParams(BaseModel):
    """Generation hyper-parameters passed through to model.generate()."""
    max_new_tokens: int = 128
    temperature: float = 0.0
    top_p: float = 1.0
    do_sample: bool = False


class InferRequest(BaseModel):
    """Body of POST /infer."""
    model_path: str
    prompt: str
    # Each entry may be an http(s) URL, a local path, a data: URL, or raw
    # base64 — see load_image() for the accepted forms.
    images: List[str]
    generation: GenParams = GenParams()
    dtype: str = "auto"  # "auto"|"float16"|"bfloat16"|"float32"
    warmup_runs: int = 1
    # NOTE(review): accepted but never read by the /infer handler in this file
    measure_token_times: bool = False


class InferResponse(BaseModel):
    """Body returned by POST /infer."""
    output_text: str
    # Per-stage wall times in ms: preprocess / generate / postprocess / e2e
    timings_ms: Dict[str, float]
    device: Dict[str, Any]
    model_info: Dict[str, Any]


class LoadModelRequest(BaseModel):
    """Body of POST /load_model."""
    model_path: str
    dtype: str = "auto"


class UnloadModelRequest(BaseModel):
    """Body of POST /unload_model."""
    model_path: str
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Model loading
|
||||
# ----------------------------
|
||||
def resolve_model(model_path: str, dtype_str: str) -> LoadedModel:
    """Return a cached LoadedModel for *model_path*, loading it on first use.

    Dispatch is driven by the HF config's ``model_type``:
      * qwen2_vl           -> Qwen2VLForConditionalGeneration + AutoProcessor
      * internlmxcomposer2 -> AutoModelForCausalLM (forced fp16, device_map="auto")
      * gemma3             -> Gemma3ForConditionalGeneration + AutoProcessor
      * anything else      -> AutoModel / AutoModelForVision2Seq /
                              AutoModelForCausalLM tried in that order
    Loaded models are cached in the module-level ``_loaded`` registry keyed
    by path, so repeated calls are cheap.
    """
    if model_path in _loaded:
        return _loaded[model_path]

    dev = best_device()
    dt = pick_dtype(dtype_str, dev)

    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = cfg.model_type
    print(f"model type detected: {model_type}, device: {dev}, dt: {dt}")

    if model_type in ("qwen2_vl", "qwen2-vl"):
        print(f"Loading Qwen2-VL using Qwen2VLForConditionalGeneration")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,
            trust_remote_code=True,
        )
        print("Loaded model class:", type(model))
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        model.to(dev)
        model.eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm
    elif model_type in ("internlmxcomposer2",):
        # Bug fix: the original `in ("internlmxcomposer2")` compared against a
        # plain string (the parentheses did not form a tuple), which is a
        # substring test — e.g. model_type "internlm" would wrongly match.
        # A one-element tuple restores exact-match semantics.
        dt = torch.float16  # this family is loaded in fp16 regardless of request
        print(f"dt change to {dt}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dt, trust_remote_code=True, device_map='auto')
        model = model.eval()
        lm = LoadedModel(model_type, model_path, model, None, tokenizer, dev, dt)
        _loaded[model_path] = lm
        return lm
    elif model_type in ("gemma3", "gemma-3", "gemma_3"):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dt if dt != torch.float32 else None,
            device_map=None,  # we move to device below
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        model.to(dev).eval()
        lm = LoadedModel(model_type, model_path, model, processor, None, dev, dt)
        _loaded[model_path] = lm
        return lm

    # Generic fallback path for any other model family
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # Tokenizer is often part of Processor; still try to load explicitly for safety
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
    except Exception:
        tokenizer = None

    model = None
    errors = []
    for candidate in (AutoModel, AutoModelForVision2Seq, AutoModelForCausalLM):
        try:
            model = candidate.from_pretrained(
                model_path,
                torch_dtype=dt if dt != torch.float32 else None,
                device_map=None,  # we move to device manually
                trust_remote_code=True
            )
            break
        except Exception as e:
            errors.append(str(e))
    if model is None:
        raise RuntimeError(f"Unable to load model {model_path}. Errors: {errors}")

    model.to(dev)
    model.eval()

    lm = LoadedModel(model_type, model_path, model, processor, tokenizer, dev, dt)
    _loaded[model_path] = lm
    return lm
|
||||
|
||||
|
||||
def unload_model(model_path: str):
    """Remove *model_path* from the registry and free its memory.

    Safe to call for a path that was never loaded.  Teardown errors are
    deliberately ignored — unloading is best-effort.
    """
    if model_path not in _loaded:
        return
    lm = _loaded.pop(model_path)
    try:
        del lm.model
        del lm.processor
        if lm.tokenizer:
            del lm.tokenizer
        gc.collect()
        if lm.device.type == "cuda":
            torch.cuda.empty_cache()
        elif lm.device.type == "mlu" and hasattr(torch, "mlu"):
            # Cambricon's torch_mlu caches allocations like CUDA does; without
            # this the device memory was never returned after unload.
            # Guarded by hasattr for builds without torch_mlu.
            torch.mlu.empty_cache()
    except Exception:
        pass
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Core inference
|
||||
# ----------------------------
|
||||
def prepare_inputs(lm: "LoadedModel", prompt: str, images: List["Image.Image"]) -> Dict[str, Any]:
    """Encode *prompt* plus *images* and move resulting tensors to lm.device.

    Uses the processor's chat template when available (covers mllama,
    qwen2vl, MiniCPM-V2, ...).  Bug fix: the original emitted exactly one
    {"type": "image"} placeholder no matter how many images were supplied
    (its own comment said "repeat if >1"), silently dropping extra images;
    one placeholder per image is now emitted.
    """
    proc = lm.processor

    if hasattr(proc, "apply_chat_template"):
        # One image placeholder per supplied image, then the text part
        content: List[Dict[str, Any]] = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        conversation = [{"role": "user", "content": content}]
        text_prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)

        encoded = proc(
            text=[text_prompt],  # list is required by some processors
            images=images,       # list-of-PIL
            padding=True,
            return_tensors="pt",
        )
    else:
        # Generic fallback for processors without a chat template
        encoded = proc(text=prompt, images=images, return_tensors="pt")

    # Move only tensors; leave auxiliary entries (e.g. image-size lists) alone
    return {k: v.to(lm.device) if torch.is_tensor(v) else v for k, v in encoded.items()}
|
||||
|
||||
|
||||
def generate_text(lm: LoadedModel, inputs: Dict[str, Any], gen: GenParams) -> str:
    """Run generate() under autocast and decode the first output sequence.

    Decoding tries, in order: an explicit tokenizer, a tokenizer embedded
    in the processor, then a model-provided decode(); returns
    "<decode_failed>" when all of those are unavailable or raise.
    """
    gen_kwargs = {
        "max_new_tokens": gen.max_new_tokens,
        "temperature": gen.temperature,
        "top_p": gen.top_p,
        "do_sample": gen.do_sample,
    }

    with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
        out_ids = lm.model.generate(**inputs, **gen_kwargs)

    seq = out_ids[0]
    if lm.tokenizer:
        return lm.tokenizer.decode(seq, skip_special_tokens=True)
    # Some processors carry their tokenizer inside
    embedded_tok = getattr(lm.processor, "tokenizer", None)
    if embedded_tok is not None:
        return embedded_tok.decode(seq, skip_special_tokens=True)
    # Last resort
    try:
        return lm.model.decode(seq)
    except Exception:
        return "<decode_failed>"
|
||||
|
||||
|
||||
def time_block(fn, sync, *args, **kwargs) -> Tuple[Any, float]:
    """Call ``fn(*args, **kwargs)``, then ``sync()``; return (result, elapsed_ms).

    *sync* should flush asynchronous device work so the measured wall time
    covers the real computation, not just kernel launches.
    """
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    sync()
    elapsed_ms = (time.perf_counter() - t0) * 1000.0
    return result, elapsed_ms
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Routes
|
||||
# ----------------------------
|
||||
@app.get("/health")
|
||||
def health():
|
||||
dev = best_device()
|
||||
return {
|
||||
"status": "ok",
|
||||
"device": str(dev),
|
||||
"device_name": device_name(dev),
|
||||
"torch": torch.__version__,
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
|
||||
"xpu_available": hasattr(torch, "xpu") and torch.xpu.is_available(),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/info")
|
||||
def info():
|
||||
dev = best_device()
|
||||
return {
|
||||
"device": {
|
||||
"type": dev.type,
|
||||
"name": device_name(dev),
|
||||
"total_memory_gb": device_total_mem_gb(dev)
|
||||
},
|
||||
"torch": torch.__version__,
|
||||
"transformers": __import__("transformers").__version__
|
||||
}
|
||||
|
||||
|
||||
@app.post("/load_model")
|
||||
def load_model(req: LoadModelRequest):
|
||||
lm = resolve_model(req.model_path, req.dtype)
|
||||
print(f"model with path {req.model_path} loaded!")
|
||||
return {
|
||||
"loaded": lm.model_path,
|
||||
"device": str(lm.device),
|
||||
"dtype": str(lm.dtype)
|
||||
}
|
||||
|
||||
|
||||
@app.get("/model_status")
|
||||
def model_status(model_path: str = Query(..., description="Path of the model to check")):
|
||||
lm = _loaded.get(model_path)
|
||||
if lm:
|
||||
return {
|
||||
"loaded": True,
|
||||
"model_path": lm.model_path,
|
||||
"device": str(lm.device),
|
||||
"dtype": str(lm.dtype)
|
||||
}
|
||||
return {
|
||||
"loaded": False,
|
||||
"model_path": model_path
|
||||
}
|
||||
|
||||
|
||||
@app.post("/unload_model")
|
||||
def unload(req: UnloadModelRequest):
|
||||
unload_model(req.model_path)
|
||||
return {"unloaded": req.model}
|
||||
|
||||
|
||||
def handle_normal_case(lm: LoadedModel, warmup_runs: int, images: List[str], prompt: str, generation: GenParams):
    """Generic prepare/generate path with optional warmup and per-stage timing.

    Returns ``(text, t_pre_ms, t_gen_ms, t_post_ms)``.
    """
    # Warm the kernels with a tiny dummy request; a failing warmup is not
    # fatal — stop warming and proceed to the real request.
    for _ in range(max(0, warmup_runs)):
        try:
            dummy = Image.new("RGB", (64, 64), color=(128, 128, 128))
            generate_text(lm, prepare_inputs(lm, "Hello", [dummy]), GenParams(max_new_tokens=8))
            synchronize(lm.device)
        except Exception:
            break

    # Resolve image references up front; a bad reference is a client error
    try:
        pil_images = [load_image(s) for s in images]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")

    # Timed stages: preprocess, generate, and a placeholder postprocess
    # stage for future output cleaning/parsing.
    synchronize(lm.device)
    inputs, t_pre = time_block(lambda: prepare_inputs(lm, prompt, pil_images), lambda: synchronize(lm.device))
    text, t_gen = time_block(lambda: generate_text(lm, inputs, generation), lambda: synchronize(lm.device))
    _, t_post = time_block(lambda: None, lambda: synchronize(lm.device))
    return text, t_pre, t_gen, t_post
|
||||
|
||||
|
||||
def handle_minicpmv(lm: LoadedModel, image: Image.Image, prompt: str, gen: GenParams):
    """MiniCPM-V path: drive the model's built-in ``chat`` API.

    Returns ``(text, t_pre_ms, t_gen_ms, t_post_ms)``; only generation is
    timed — pre/post are reported as 0.0 because chat() does its own
    preprocessing internally.

    NOTE(review): ``gen.max_new_tokens`` and ``gen.top_p`` are not
    forwarded to ``model.chat`` here — confirm the chat API's defaults are
    acceptable.
    """
    def generate_text_chat() -> str:
        # Prepare msgs in the format expected by model.chat
        msgs = [{"role": "user", "content": prompt}]

        # Call the model's built-in chat method
        response = lm.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=lm.tokenizer,
            sampling=gen.do_sample,
            temperature=gen.temperature,
            stream=False  # Set True if you want streaming later
        )
        return response

    # Run chat-based inference, timing the full chat() call
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API

    return text, t_pre, t_gen, t_post
|
||||
|
||||
|
||||
def handle_internlm_xcomposer(lm: LoadedModel,
                              images_pil: List[Image.Image],
                              prompt: str,
                              gen: GenParams):
    """InternLM-XComposer2 path: CLIP-preprocess images, then ``model.chat``.

    Returns ``(text, t_pre_ms, t_gen_ms, t_post_ms)``; only generation is
    timed — pre/post are reported as 0.0 because preprocessing happens
    inside the timed closure.
    """
    def generate_text_chat():
        # Preprocess every image with the model-supplied CLIP transform
        imgs = [lm.model.vis_processor(img.convert("RGB")) for img in images_pil]
        batch = torch.stack(imgs).to(lm.device, dtype=lm.dtype)

        # Build the query string - one <ImageHere> token per picture
        query = ("<ImageHere> " * len(images_pil)).strip() + " " + prompt

        # Run chat-style generation under autocast
        with torch.no_grad(), autocast_ctx(lm.device, lm.dtype):
            response, _ = lm.model.chat(
                lm.tokenizer,
                query=query,
                image=batch,
                history=[],
                do_sample=gen.do_sample,
                temperature=gen.temperature,
                max_new_tokens=gen.max_new_tokens,
            )
        return response

    # Run chat-based inference, timing the full closure
    synchronize(lm.device)
    text, t_gen = time_block(
        lambda: generate_text_chat(),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API

    return text, t_pre, t_gen, t_post
|
||||
|
||||
|
||||
def handle_qwen2vl(lm: LoadedModel, image_strings: List[str], prompt: str, gen: GenParams):
    """Qwen2-VL path: chat-template prompt with a single image.

    Returns ``(text, t_pre_ms, t_gen_ms, t_post_ms)``; timings are not
    measured on this path and are reported as 0.0.

    NOTE(review): only the first supplied image is used, matching the
    single {"type": "image"} placeholder below — confirm multi-image
    support is not expected here.
    """
    images = [load_image(s) for s in image_strings]
    image = images[0]

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text_prompt = lm.processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = lm.processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(lm.device)

    output_ids = lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens)

    # Strip the echoed prompt tokens from each generated sequence.
    # (The original comprehension shadowed `output_ids` with its loop
    # variable; distinct names avoid that trap.)
    trimmed = [
        out_seq[len(in_seq):]
        for in_seq, out_seq in zip(inputs.input_ids, output_ids)
    ]
    texts = lm.processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # Bug fix: batch_decode returns a list, but InferResponse.output_text
    # is declared `str` — return the first (and only) decoded string so
    # response validation does not fail.
    return texts[0], 0.0, 0.0, 0.0
|
||||
|
||||
|
||||
def handle_gemma3(lm: LoadedModel, image_refs: List[str], prompt: str, gen: GenParams):
    """Gemma-3 path: chat-template prompt, timed generate, full-sequence decode.

    Returns ``(text, t_pre_ms, t_gen_ms, t_post_ms)``; only generation is
    timed.  Only the first image reference is used (single placeholder in
    the message below).
    """
    img = load_image(image_refs[0])

    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    text_prompt = lm.processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = lm.processor(
        img,
        text_prompt,
        add_special_tokens=False,
        return_tensors="pt"
    ).to(lm.device)

    # Run generation, timing only the generate() call
    synchronize(lm.device)
    out_ids, t_gen = time_block(
        lambda: lm.model.generate(**inputs, max_new_tokens=gen.max_new_tokens),
        lambda: synchronize(lm.device)
    )
    t_pre, t_post = 0.0, 0.0  # Not needed with chat API

    # NOTE(review): this decodes the whole sequence, so the echoed prompt
    # is included in output_text — confirm that is intended.
    return lm.processor.decode(out_ids[0], skip_special_tokens=True), t_pre, t_gen, t_post
|
||||
|
||||
|
||||
@app.post("/infer", response_model=InferResponse)
|
||||
def infer(req: InferRequest):
|
||||
print("infer got")
|
||||
# Load / reuse model
|
||||
lm = resolve_model(req.model_path, req.dtype)
|
||||
print(f"{lm.model_type=}")
|
||||
if lm.model_type == 'minicpmv':
|
||||
text, t_pre, t_gen, t_post = handle_minicpmv(lm, load_image(req.images[0]), req.prompt, req.generation)
|
||||
elif lm.model_type in ("qwen2vl", "qwen2-vl", "qwen2_vl"):
|
||||
text, t_pre, t_gen, t_post = handle_qwen2vl(lm, req.images, req.prompt, req.generation)
|
||||
elif lm.model_type == "internlmxcomposer2":
|
||||
# Load images
|
||||
try:
|
||||
pil_images = [load_image(s) for s in req.images]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to load images: {e}")
|
||||
text, t_pre, t_gen, t_post = handle_internlm_xcomposer(lm, pil_images, req.prompt, req.generation)
|
||||
else:
|
||||
text, t_pre, t_gen, t_post = handle_normal_case(lm, req.warmup_runs, req.images, req.prompt, req.generation)
|
||||
timings = {
|
||||
"preprocess": t_pre,
|
||||
"generate": t_gen,
|
||||
"postprocess": t_post,
|
||||
"e2e": t_pre + t_gen + t_post
|
||||
}
|
||||
|
||||
return InferResponse(
|
||||
output_text=text,
|
||||
timings_ms=timings,
|
||||
device={
|
||||
"type": lm.device.type,
|
||||
"name": device_name(lm.device),
|
||||
"total_memory_gb": device_total_mem_gb(lm.device)
|
||||
},
|
||||
model_info={
|
||||
"name": lm.model_path,
|
||||
"precision": str(lm.dtype).replace("torch.", ""),
|
||||
"framework": "transformers"
|
||||
}
|
||||
)
|
||||
|
||||
# Entry
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Reference in New Issue
Block a user