144 lines
4.6 KiB
Python
144 lines
4.6 KiB
Python
import os
|
|
import io
|
|
import base64
|
|
from typing import Any, Dict, List, Optional
|
|
import soundfile as sf
|
|
import torch
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from PIL import Image
|
|
from starlette.responses import JSONResponse
|
|
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
|
|
from qwen_omni_utils import process_mm_info
|
|
import time
|
|
|
|
# Allocator configuration for the MLU backend. setdefault() keeps any value the
# operator already exported instead of silently overwriting it.
# NOTE(review): `torch` is imported above this line; confirm the MLU allocator
# reads this variable lazily (on first allocation), otherwise it must be set
# before `import torch` to take effect.
os.environ.setdefault("PYTORCH_MLU_ALLOC_CONF", "expandable_segments:True")

# Local checkpoint directory. Override with QWEN3_OMNI_MODEL_PATH; the default
# is the path this service has always used.
MODEL_PATH = os.environ.get(
    "QWEN3_OMNI_MODEL_PATH", "/mnt/models/Qwen3-Omni-30B-A3B-Instruct"
)

# Load model and processor once at import time; every request reuses them.
print("Loading model...")
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    dtype=torch.float16,
    device_map="auto",
)
# This service never returns generated audio, so the talker (audio-output)
# component is disabled.
model.disable_talker()
processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH)
print("✅ Model & processor loaded.")


# =========================
# FastAPI application
# =========================
app = FastAPI(title="Qwen3-Omni vllm-format wrapper")


# Liveness probe. The model and processor are loaded at module import time,
# before `app` is created, so if this route answers at all, loading finished.
@app.get("/health")
def health_check():
    """Report that the service is up: HTTP 200 with body {"status": "ok"}."""
    body = {"status": "ok"}
    return JSONResponse(content=body, status_code=200)


# =========================
# Request schema
# =========================
class GenerateRequest(BaseModel):
    """Request body accepted by ``POST /generate``.

    NOTE(review): only the outer shape is validated here (a list of JSON
    objects); the per-message layout is presumably OpenAI-style
    ``{"role": ..., "content": ...}`` -- confirm against clients.
    """

    # Chat history; pydantic guarantees each item is a JSON object.
    messages: List[Dict[str, Any]]
    # Forwarded as `speaker=` to model.generate by /generate in this file.
    # NOTE(review): model.disable_talker() is called at load time; whether this
    # still affects anything is unverified.
    speaker: Optional[str] = "Ethan"
    # NOTE(review): /generate in this file always passes
    # use_audio_in_video=False; this field is currently not read.
    use_audio_in_video: Optional[bool] = True
    # NOTE(review): /generate in this file always resizes data-URI images to
    # 224x224; resize_image and image_size are currently not read.
    resize_image: Optional[bool] = True
    image_size: Optional[int] = 448


# =========================
# Utility functions
# =========================
def _decode_data_uri_image(data_uri: str) -> "Image.Image":
    """Decode a ``data:`` URI into a fully loaded PIL image.

    Everything after the first ``,`` is base64-decoded, regardless of what the
    header part (e.g. ``data:image/png;base64``) declares.

    Args:
        data_uri: String that must start with ``data:``.

    Returns:
        The decoded image, with pixel data already loaded.

    Raises:
        ValueError: the string is not a ``data:`` URI, has no ``,`` separator,
            or the payload is malformed base64 (e.g. incorrect padding;
            ``binascii.Error`` is a ``ValueError`` subclass).
        OSError: the bytes are not a readable image (PIL's
            ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    if not data_uri.startswith("data:"):
        raise ValueError("Not a data URI")
    _header, sep, payload = data_uri.partition(",")
    if not sep:
        raise ValueError("Malformed data URI: missing ',' before the payload")
    raw = base64.b64decode(payload)
    image = Image.open(io.BytesIO(raw))
    # Image.open() only parses the header; force the full decode here so a
    # truncated/corrupt payload fails inside this helper rather than at the
    # caller's first pixel access.
    image.load()
    return image


# =========================
# Main inference endpoint
# =========================

# Upper bound on new tokens generated per request.
MAX_NEW_TOKENS = 128

# Fixed (width, height) every inline data-URI image is resized to.
# NOTE(review): GenerateRequest.resize_image / image_size are accepted but not
# honored; switching to them would change the output of default requests.
DATA_URI_IMAGE_SIZE = (224, 224)


def _normalize_content(raw_content: Any) -> List[Dict[str, Any]]:
    """Convert one message's ``content`` into Qwen chat-template parts.

    Accepts ``None`` (no parts), a plain string (treated as one text part, the
    OpenAI shorthand), or a list of part objects:

    * ``{"type": "image_url", "image_url": {"url": u}}`` (or
      ``"image_url": u``) becomes ``{"type": "image", "image": u}``;
    * ``{"type": "text", "text": t}`` is copied (``t`` defaults to ``""``);
    * any other part object is passed through unchanged.

    Raises:
        HTTPException(400): ``content`` is not None/str/list, or one of its
            parts is not a JSON object.
    """
    if raw_content is None:
        return []
    if isinstance(raw_content, str):
        return [{"type": "text", "text": raw_content}]
    if not isinstance(raw_content, list):
        raise HTTPException(
            status_code=400,
            detail="message content must be a string or a list of content parts",
        )

    parts: List[Dict[str, Any]] = []
    for part in raw_content:
        if not isinstance(part, dict):
            raise HTTPException(
                status_code=400, detail="each content part must be a JSON object"
            )
        ctype = part.get("type")
        if ctype == "image_url":
            image_url = part.get("image_url")
            url = image_url.get("url") if isinstance(image_url, dict) else image_url
            parts.append({"type": "image", "image": url})
        elif ctype == "text":
            parts.append({"type": "text", "text": part.get("text", "")})
        else:
            parts.append(part)
    return parts


def _collect_images(conversation: List[Dict[str, Any]]) -> List[Image.Image]:
    """Decode and resize every image part of ``conversation``, in order.

    Only base64 ``data:`` URIs can be served. The processor receives just the
    rendered prompt plus the PIL images returned here. Any other reference
    (http URL, file path, missing URL) would still reach the chat template as
    an image part while no pixel data reaches the processor, so it is
    rejected up front.

    Raises:
        HTTPException(400): an image reference is not a data URI, or it could
            not be decoded/resized.
    """
    images: List[Image.Image] = []
    for msg in conversation:
        for part in msg["content"]:
            if part.get("type") != "image":
                continue
            ref = part.get("image")
            if not (isinstance(ref, str) and ref.startswith("data:")):
                raise HTTPException(
                    status_code=400,
                    detail="unsupported image reference: only base64 data URIs are accepted",
                )
            try:
                images.append(_decode_data_uri_image(ref).resize(DATA_URI_IMAGE_SIZE))
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail=f"failed to decode data URI image: {e}"
                ) from e
    return images


@app.post("/generate")
def generate(req: GenerateRequest):
    """Run one chat turn and return an OpenAI/vLLM-style ``chat.completion``.

    Steps:
    1. Normalize ``req.messages`` into the chat-template layout.
    2. Decode the inline images.
    3. Render the prompt.
    4. Generate at most ``MAX_NEW_TOKENS`` tokens.
    5. Decode only the newly generated tokens.

    Raises:
        HTTPException(400): empty ``messages``, malformed content, a
            non-data-URI image reference, or an undecodable image.
        HTTPException(500): preprocessing, generation or decoding failed.
    """
    messages = req.messages
    if not messages or not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")

    conversation = [
        {"role": m.get("role", "user"), "content": _normalize_content(m.get("content", []))}
        for m in messages
    ]

    # Validate/decode images before doing any model work.
    images = _collect_images(conversation)
    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    try:
        inputs = processor(
            text=text,
            # Pass None (the processor's default) rather than [] when there are
            # no images, so the image processor is never handed an empty batch.
            images=images or None,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=False,
        ).to(model.device, dtype=torch.float16)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"processor failed: {e}") from e

    try:
        # Second tuple element is the audio output; the talker is disabled at
        # load time and audio is never returned, so it is discarded.
        text_ids, _audio = model.generate(
            **inputs,
            speaker=req.speaker,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=False,
            max_new_tokens=MAX_NEW_TOKENS,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"model.generate failed: {e}") from e

    try:
        # Drop the prompt tokens; keep only what the model generated.
        generated = text_ids.sequences[:, inputs["input_ids"].shape[1]:]
        out_texts = processor.batch_decode(
            generated,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"processor.batch_decode failed: {e}") from e

    # Heuristic: using the whole token budget means the answer was most likely
    # cut off, which OpenAI-compatible clients expect to see as "length".
    finish_reason = "length" if generated.shape[1] >= MAX_NEW_TOKENS else "stop"

    result = {
        "id": "chatcmpl",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": "Qwen3-Omni",
        "choices": [
            {
                "index": 0,
                # NOTE(review): OpenAI/vLLM clients expect a string here; this is
                # a list of decoded strings (one per batch row). Kept as-is for
                # compatibility with existing callers.
                "message": {"role": "assistant", "content": out_texts},
                "finish_reason": finish_reason,
            }
        ],
    }
    return JSONResponse(result)