import base64
import io
import os
import time
from typing import Any, Dict, List, Optional

# Set the allocator config before torch is imported so it is already in place
# when the MLU caching allocator initializes.
os.environ["PYTORCH_MLU_ALLOC_CONF"] = "expandable_segments:True"

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from starlette.responses import JSONResponse
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor

MODEL_PATH = "/mnt/models/Qwen3-Omni-30B-A3B-Instruct"

print("Loading model...")
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    dtype=torch.float16,
    device_map="auto",
)
# Text-only serving: disabling the talker skips speech synthesis and saves memory.
model.disable_talker()
processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH)
print("✅ Model & processor loaded.")

# =========================
# Initialize FastAPI
# =========================
app = FastAPI(title="Qwen3-Omni vllm-format wrapper")


# Health check, defined after the model has loaded successfully.
@app.get("/health")
def health_check():
    return JSONResponse(status_code=200, content={"status": "ok"})


# =========================
# Request schema
# =========================
class GenerateRequest(BaseModel):
    messages: List[Dict[str, Any]]
    speaker: Optional[str] = "Ethan"
    use_audio_in_video: Optional[bool] = True
    resize_image: Optional[bool] = True
    image_size: Optional[int] = 448
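

# Example request body for /generate (illustrative only; the base64 payload is
# truncated):
#   {
#     "messages": [
#       {"role": "user", "content": [
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#         {"type": "text", "text": "Describe this image."}
#       ]}
#     ],
#     "resize_image": true,
#     "image_size": 448
#   }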

# =========================
# Utility functions
# =========================
def _decode_data_uri_image(data_uri: str) -> Image.Image:
    """Decode a base64 data URI into a PIL image."""
    if not data_uri.startswith("data:"):
        raise ValueError("Not a data URI")
    _header, b64 = data_uri.split(",", 1)
    decoded = base64.b64decode(b64)
    return Image.open(io.BytesIO(decoded))


# =========================
# Main inference endpoint
# =========================
@app.post("/generate")
def generate(req: GenerateRequest):
    messages = req.messages
    if not messages or not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")

    # Normalize OpenAI-style content parts into the Qwen chat-template format.
    conversation = []
    for m in messages:
        role = m.get("role", "user")
        raw_content = m.get("content", [])
        content_list = []
        for c in raw_content:
            ctype = c.get("type")
            if ctype == "image_url":
                url = c.get("image_url", {}).get("url") if isinstance(c.get("image_url"), dict) else c.get("image_url")
                content_list.append({"type": "image", "image": url})
            elif ctype == "text":
                content_list.append({"type": "text", "text": c.get("text", "")})
            else:
                content_list.append(c)
        conversation.append({"role": role, "content": content_list})

    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    # Collect inline images. Only base64 data URIs are decoded here; plain
    # http(s) image URLs are not fetched by this wrapper.
    images = []
    for msg in conversation:
        for c in msg["content"]:
            if c.get("type") == "image":
                img_ref = c.get("image")
                if isinstance(img_ref, str) and img_ref.startswith("data:"):
                    try:
                        pil = _decode_data_uri_image(img_ref).convert("RGB")
                        if req.resize_image:
                            pil = pil.resize((req.image_size, req.image_size))
                        images.append(pil)
                    except Exception as e:
                        raise HTTPException(status_code=400, detail=f"failed to decode data URI image: {e}")

    inputs = processor(
        text=text,
        images=images if images else None,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=False,  # this wrapper handles text and images only
    ).to(model.device, dtype=torch.float16)

    try:
        # With the talker disabled, the second return value (the audio waveform) is None.
        text_ids, _audio = model.generate(
            **inputs,
            speaker=req.speaker,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=False,
            max_new_tokens=128,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"model.generate failed: {e}")

    try:
        # Drop the prompt tokens and decode only the newly generated ones.
        out_texts = processor.batch_decode(
            text_ids.sequences[:, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"processor.batch_decode failed: {e}")

    # Return an OpenAI-style chat.completion payload so existing clients can parse it.
    result = {
        "id": "chatcmpl",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": "Qwen3-Omni",
        "choices": [
            {
                "index": 0,
                # batch_decode returns a list; a single request yields one string.
                "message": {"role": "assistant", "content": out_texts[0] if out_texts else ""},
                "finish_reason": "stop",
            }
        ],
    }
    return JSONResponse(result)
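

# A minimal launcher sketch so the module can be run directly; the host and
# port below are assumptions, not part of the original deployment.
# Example call once running (illustrative):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)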