"""FastAPI wrapper exposing Qwen3-Omni text generation in an OpenAI-ish format.

Accepts OpenAI-style chat messages (text parts, ``image_url`` parts with
base64 ``data:`` URIs), runs the thinker (talker disabled — text only), and
returns a ``chat.completion``-shaped JSON response.
"""

import os

# Allocator tuning: must be in the environment before the first allocation,
# so set it before torch is imported (it was previously set after).
os.environ["PYTORCH_MLU_ALLOC_CONF"] = "expandable_segments:True"

import base64
import io
import time
from typing import Any, Dict, List, Optional

import soundfile as sf  # noqa: F401  (kept: part of the file's original imports)
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from starlette.responses import JSONResponse
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info  # noqa: F401

MODEL_PATH = "/mnt/models/Qwen3-Omni-30B-A3B-Instruct"

print("Loading model...")
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    dtype=torch.float16,
    device_map="auto",
)
# Speech synthesis is disabled: this service returns text only.
model.disable_talker()
processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH)
print("✅ Model & processor loaded.")

# =========================
# FastAPI app
# =========================
app = FastAPI(title="Qwen3-Omni vllm-format wrapper")


# Defined after the model load above, so a 200 implies the model is ready.
@app.get("/health")
def health_check():
    """Liveness/readiness probe."""
    return JSONResponse(status_code=200, content={"status": "ok"})


# =========================
# Request schema
# =========================
class GenerateRequest(BaseModel):
    # OpenAI-style messages: [{"role": ..., "content": str | [parts]}, ...]
    messages: List[Dict[str, Any]]
    speaker: Optional[str] = "Ethan"
    use_audio_in_video: Optional[bool] = True
    resize_image: Optional[bool] = True
    image_size: Optional[int] = 448


# =========================
# Helpers
# =========================
def _decode_data_uri_image(data_uri: str) -> Image.Image:
    """Decode a base64 ``data:`` URI into a PIL image.

    Raises:
        ValueError: if the string is not a ``data:`` URI.
    """
    if not data_uri.startswith("data:"):
        raise ValueError("Not a data URI")
    _header, b64 = data_uri.split(",", 1)
    return Image.open(io.BytesIO(base64.b64decode(b64)))


def _normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert OpenAI-style messages into the Qwen conversation format.

    - plain-string ``content`` becomes one text part (bugfix: the OpenAI API
      allows a bare string, which the old code would have mis-iterated);
    - ``image_url`` parts become ``{"type": "image", "image": url}``;
    - unknown part types are passed through unchanged.
    """
    conversation: List[Dict[str, Any]] = []
    for m in messages:
        role = m.get("role", "user")
        raw_content = m.get("content", [])
        if isinstance(raw_content, str):
            raw_content = [{"type": "text", "text": raw_content}]
        content_list: List[Dict[str, Any]] = []
        for part in raw_content:
            ptype = part.get("type")
            if ptype == "image_url":
                ref = part.get("image_url")
                url = ref.get("url") if isinstance(ref, dict) else ref
                content_list.append({"type": "image", "image": url})
            elif ptype == "text":
                content_list.append({"type": "text", "text": part.get("text", "")})
            else:
                content_list.append(part)
        conversation.append({"role": role, "content": content_list})
    return conversation


def _collect_images(
    conversation: List[Dict[str, Any]],
    resize: bool,
    size: int,
) -> List[Image.Image]:
    """Decode every data-URI image in the conversation into RGB PIL images.

    Raises:
        HTTPException(400): when a data URI cannot be decoded.
    """
    images: List[Image.Image] = []
    for msg in conversation:
        for part in msg["content"]:
            if part.get("type") != "image":
                continue
            img_ref = part.get("image")
            if isinstance(img_ref, str) and img_ref.startswith("data:"):
                try:
                    pil = _decode_data_uri_image(img_ref)
                    # Bugfix: normalize to RGB — RGBA/palette PNGs are common
                    # data-URI payloads and break float16 vision preprocessing.
                    pil = pil.convert("RGB")
                    # Bugfix: honor the request's resize options instead of a
                    # hardcoded 224x224 that ignored resize_image/image_size.
                    if resize:
                        pil = pil.resize((size, size))
                    images.append(pil)
                except Exception as e:
                    raise HTTPException(
                        status_code=400,
                        detail=f"failed to decode data URI image: {e}",
                    )
    return images


# =========================
# Main inference endpoint
# =========================
@app.post("/generate")
def generate(req: GenerateRequest):
    """Run one chat completion and return an OpenAI ``chat.completion`` JSON."""
    messages = req.messages
    if not messages or not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")

    conversation = _normalize_messages(messages)
    text = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    images = _collect_images(
        conversation,
        resize=bool(req.resize_image),
        size=int(req.image_size or 448),
    )

    inputs = processor(
        text=text,
        # Bugfix: pass None rather than an empty list for text-only requests.
        images=images or None,
        return_tensors="pt",
        padding=True,
        # NOTE(review): req.use_audio_in_video is intentionally not forwarded —
        # this wrapper never supplies video/audio inputs (process_mm_info is
        # unused), so the flag is pinned False. TODO: wire up multimodal audio.
        use_audio_in_video=False,
    ).to(model.device, dtype=torch.float16)

    try:
        # Talker is disabled; the second element is unused here.
        # NOTE(review): assumes generate still returns a 2-tuple with the
        # talker disabled — confirm against the model's generate() contract.
        text_ids, _audio = model.generate(
            **inputs,
            speaker=req.speaker,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=False,
            max_new_tokens=128,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"model.generate failed: {e}")

    try:
        # Strip the prompt tokens; decode only the newly generated suffix.
        out_texts = processor.batch_decode(
            text_ids.sequences[:, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"processor.batch_decode failed: {e}")

    result = {
        "id": "chatcmpl",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": "Qwen3-Omni",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    # Bugfix: OpenAI schema requires a string, not a list.
                    "content": out_texts[0] if out_texts else "",
                },
                "finish_reason": "stop",
            }
        ],
    }
    return JSONResponse(result)