init
This commit is contained in:
143
app.py
Normal file
143
app.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from PIL import Image
|
||||
from starlette.responses import JSONResponse
|
||||
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
|
||||
from qwen_omni_utils import process_mm_info
|
||||
import time
|
||||
|
||||
os.environ["PYTORCH_MLU_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
||||
MODEL_PATH = "/mnt/models/Qwen3-Omni-30B-A3B-Instruct"
|
||||
|
||||
print("Loading model...")
|
||||
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH,
|
||||
trust_remote_code=True,
|
||||
dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
model.disable_talker()
|
||||
processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH)
|
||||
print("✅ Model & processor loaded.")
|
||||
|
||||
# =========================
|
||||
# 初始化 FastAPI
|
||||
# =========================
|
||||
app = FastAPI(title="Qwen3-Omni vllm-format wrapper")
|
||||
|
||||
# 在模型加载成功后定义健康检查接口
|
||||
@app.get("/health")
|
||||
def health_check():
|
||||
return JSONResponse(status_code=200, content={"status": "ok"})
|
||||
|
||||
# =========================
|
||||
# 请求 schema
|
||||
# =========================
|
||||
class GenerateRequest(BaseModel):
|
||||
messages: List[Dict[str, Any]]
|
||||
speaker: Optional[str] = "Ethan"
|
||||
use_audio_in_video: Optional[bool] = True
|
||||
resize_image: Optional[bool] = True
|
||||
image_size: Optional[int] = 448
|
||||
|
||||
# =========================
|
||||
# 工具函数
|
||||
# =========================
|
||||
def _decode_data_uri_image(data_uri: str) -> Image.Image:
|
||||
if not data_uri.startswith("data:"):
|
||||
raise ValueError("Not a data URI")
|
||||
header, b64 = data_uri.split(",", 1)
|
||||
decoded = base64.b64decode(b64)
|
||||
return Image.open(io.BytesIO(decoded))
|
||||
|
||||
# =========================
|
||||
# 主推理接口
|
||||
# =========================
|
||||
@app.post("/generate")
|
||||
def generate(req: GenerateRequest):
|
||||
messages = req.messages
|
||||
if not messages or not isinstance(messages, list):
|
||||
raise HTTPException(status_code=400, detail="messages must be a non-empty list")
|
||||
|
||||
conversation = []
|
||||
for m in messages:
|
||||
role = m.get("role", "user")
|
||||
raw_content = m.get("content", [])
|
||||
content_list = []
|
||||
for c in raw_content:
|
||||
ctype = c.get("type")
|
||||
if ctype == "image_url":
|
||||
url = c.get("image_url", {}).get("url") if isinstance(c.get("image_url"), dict) else c.get("image_url")
|
||||
content_list.append({"type": "image", "image": url})
|
||||
elif ctype == "text":
|
||||
content_list.append({"type": "text", "text": c.get("text", "")})
|
||||
else:
|
||||
content_list.append(c)
|
||||
conversation.append({"role": role, "content": content_list})
|
||||
|
||||
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
||||
|
||||
images = []
|
||||
for msg in conversation:
|
||||
for c in msg["content"]:
|
||||
if c.get("type") == "image":
|
||||
img_ref = c.get("image")
|
||||
if isinstance(img_ref, str) and img_ref.startswith("data:"):
|
||||
try:
|
||||
pil = _decode_data_uri_image(img_ref)
|
||||
pil = pil.resize((224, 224))
|
||||
images.append(pil)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"failed to decode data URI image: {e}")
|
||||
|
||||
inputs = processor(
|
||||
text=text,
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
use_audio_in_video=False,
|
||||
).to(model.device, dtype=torch.float16)
|
||||
|
||||
try:
|
||||
text_ids, audio_out = model.generate(
|
||||
**inputs,
|
||||
speaker=req.speaker,
|
||||
thinker_return_dict_in_generate=True,
|
||||
use_audio_in_video=False,
|
||||
max_new_tokens=128,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"model.generate failed: {e}")
|
||||
|
||||
try:
|
||||
out_texts = processor.batch_decode(
|
||||
text_ids.sequences[:, inputs["input_ids"].shape[1]:],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"processor.batch_decode failed: {e}")
|
||||
|
||||
result = {
|
||||
"id": "chatcmpl",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": "Qwen3-Omni",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": out_texts},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return JSONResponse(result)
|
||||
|
||||
Reference in New Issue
Block a user