[Test] Add initial multi modal cases of Qwen2.5-VL-7B-Instruct for disaggregated encoder (#5301)

### What this PR does / why we need it?
This PR adds disaggregated encoder tests for Qwen2.5-VL-7B-Instruct.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By running the new e2e tests locally and in CI.
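For example, the new multicard case can be run directly with pytest (illustrative invocation, assuming an Ascend environment already set up for the vllm-ascend e2e tests):

```bash
pytest -s tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
```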

- vLLM version: release/v0.12.0

---------

Signed-off-by: wangyu31577 <wangyu31577@hundsun.com>
Signed-off-by: wangyu <53896905+yenuo26@users.noreply.github.com>
Co-authored-by: wangyu31577 <wangyu31577@hundsun.com>
This commit is contained in:
wangyu
2026-02-06 17:30:17 +08:00
committed by GitHub
parent 06c0aed124
commit c63b7a1188
8 changed files with 1361 additions and 1 deletions

View File

@@ -126,6 +126,9 @@ jobs:
- name: qwen2-5-vl-7b
os: linux-aarch64-a3-4
tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b.py
- name: qwen2-5-vl-7b-epd
os: linux-aarch64-a3-4
tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b_epd.py
- name: qwen2-5-vl-32b
os: linux-aarch64-a3-4
tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_32b.py

View File

@@ -126,6 +126,8 @@ e2e-multicard-2-cards:
estimated_time: 1050
- name: tests/e2e/multicard/2-cards/test_single_request_aclgraph.py
estimated_time: 215
- name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
estimated_time: 90
e2e-multicard-4-cards:
# TODO: recover skipped tests

View File

@@ -0,0 +1,206 @@
#!/bin/bash
set -euo pipefail
declare -a PIDS=()
###############################################################################
# Configuration -- override via env before running
###############################################################################
MODEL="${MODEL:-Qwen/Qwen2.5-VL-7B-Instruct}"
LOG_PATH="${LOG_PATH:-./logs}"
mkdir -p $LOG_PATH
ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}"
PROXY_PORT="${PROXY_PORT:-10001}"
CARD_E="${CARD_E:-0}"
CARD_PD="${CARD_PD:-1}"
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
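# Example override (illustrative only; the script filename below is a placeholder,
# not something defined in this repo):
#   ENCODE_PORT=19600 PREFILL_DECODE_PORT=19601 NUM_PROMPTS=20 \
#   bash ./run_disagg_epd.sh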
###############################################################################
# Helpers
###############################################################################
# vLLM repository root (hardcoded; used to locate the test image)
VLLM_ROOT="/vllm-workspace/vllm"
START_TIME=$(date +"%Y%m%d_%H%M%S")
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
PD_LOG=$LOG_PATH/pd_${START_TIME}.log
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
wait_for_server() {
local port=$1
timeout "$TIMEOUT_SECONDS" bash -c "
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Stopping everything…"
trap - INT TERM USR1 # prevent re-entrancy
# Kill all tracked PIDs
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill "$pid" 2>/dev/null
fi
done
# Wait a moment for graceful shutdown
sleep 2
# Force kill any remaining processes
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -9 "$pid" 2>/dev/null
fi
done
# Kill the entire process group as backup
kill -- -$$ 2>/dev/null
echo "All processes stopped."
exit 0
}
trap cleanup INT
trap cleanup USR1
trap cleanup TERM
# clear previous cache
echo "remove previous ec cache folder"
rm -rf $EC_SHARED_STORAGE_PATH
echo "make ec cache folder"
mkdir -p $EC_SHARED_STORAGE_PATH
###############################################################################
# Encoder worker
###############################################################################
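# This instance is registered as the ec_producer: encoder outputs are written to
# EC_SHARED_STORAGE_PATH so that the prefill/decode instance below (ec_consumer)
# can load them instead of re-running the vision encoder.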
ASCEND_RT_VISIBLE_DEVICES="$CARD_E" vllm serve "$MODEL" \
--gpu-memory-utilization 0.01 \
--port "$ENCODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--no-enable-prefix-caching \
--max-num-batched-tokens 114688 \
--max-num-seqs 128 \
--ec-transfer-config '{
"ec_connector": "ECExampleConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${ENC_LOG}" 2>&1 &
PIDS+=($!)
###############################################################################
# Prefill+Decode worker
###############################################################################
ASCEND_RT_VISIBLE_DEVICES="$CARD_PD" vllm serve "$MODEL" \
--gpu-memory-utilization 0.9 \
--port "$PREFILL_DECODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--max-num-seqs 128 \
--ec-transfer-config '{
"ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${PD_LOG}" 2>&1 &
PIDS+=($!)
# Wait for workers
wait_for_server $ENCODE_PORT
wait_for_server $PREFILL_DECODE_PORT
###############################################################################
# Proxy
###############################################################################
python ./disagg_epd_proxy.py \
--host "0.0.0.0" \
--port "$PROXY_PORT" \
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
--prefill-servers-urls "disable" \
--decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
>"${PROXY_LOG}" 2>&1 &
PIDS+=($!)
wait_for_server $PROXY_PORT
echo "All services are up!"
###############################################################################
# Single request with local image
###############################################################################
echo "Running single request with local image (non-stream)..."
echo "Running single request with local image (non-stream)..."
base64_image=$(base64 -w 0 "${VLLM_ROOT}/tests/v1/ec_connector/integration/hato.jpg")
cat > /tmp/request.json << EOF
{
"model": "${MODEL}",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "data:image/jpg;base64,${base64_image}"
}
},
{
"type": "text",
"text": "What is in this image?"
}
]
}
]
}
EOF
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-d @/tmp/request.json
rm -f /tmp/request.json
###############################################################################
# Benchmark
###############################################################################
echo "Running benchmark (stream)..."
vllm bench serve \
--model $MODEL \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name random-mm \
--seed 0 \
--num-prompts $NUM_PROMPTS \
--port $PROXY_PORT
# cleanup
echo "cleanup..."
cleanup

View File

@@ -0,0 +1,749 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
disagg_epd_proxy.py
Proxy that routes OpenAI-compatible "/v1/chat/completions" requests to two
clusters:
• encode (multimodal feature extraction)
• decode (language-model inference)
For MM input we:
1. Extract *every* image/audio item.
2. Fire N concurrent requests to the encoder cluster
(one request per item, with **all text removed**).
3. Wait for all of them to succeed.
4. Forward the *original* request to a decode server.
"""
from __future__ import annotations
import argparse
import asyncio
import copy
import logging
import os
import random
import uuid
from collections.abc import AsyncIterator
from enum import Enum
import aiohttp
import uvicorn
from aiohttp import ClientResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
###############################################################################
# FastAPI app & global state
###############################################################################
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s")
logger = logging.getLogger("proxy")
app = FastAPI()
encode_session: aiohttp.ClientSession | None = None
prefill_session: aiohttp.ClientSession | None = None
decode_session: aiohttp.ClientSession | None = None
###############################################################################
# Utils
###############################################################################
MM_TYPES = {"image_url", "audio_url", "input_audio"}
class EncoderDispatchMode(str, Enum):
SINGLE = "single"
FANOUT = "fanout"
def extract_mm_items(request_data: dict) -> list[dict]:
"""
Return *all* image/audio items that appear anywhere in `messages`.
Each returned dict looks like:
{ "type": "image_url", "image_url": {...} }
"""
items: list[dict] = []
for msg in request_data.get("messages", []):
content = msg.get("content")
if not isinstance(content, list):
continue
for item in content:
if item.get("type") in MM_TYPES:
items.append(item)
return items
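# Example (illustrative request body): a message list such as
#   [{"role": "user", "content": [
#       {"type": "text", "text": "What is in this image?"},
#       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]}]
# yields [{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}];
# messages whose content is a plain string are skipped.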
async def _encode_fanout(
orig_request: dict,
e_urls: list[str],
req_id: str,
):
logger.info("[%s] Processing multimodal items...", req_id)
mm_items = extract_mm_items(orig_request)
if not mm_items:
logger.info("[%s] No multimodal items, skipping encoder", req_id)
return # nothing to do
logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))
tasks = []
# Round-robin over encode servers to distribute load a bit
url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))
for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
# Derive a *child* request id: <parent>:<index>:<random-short>
child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
headers = {"x-request-id": child_req_id}
encoder_req = {
# You *may* need to keep additional fields
"model": orig_request.get("model"),
"messages": [
{"role": "user", "content": [item]},
],
# Only need 1 token so the server actually runs the encoder path
"max_tokens": 1,
"stream": False,
}
if encode_session is None:
raise HTTPException(status_code=500, detail="Encode session not initialized")
tasks.append(
encode_session.post(
f"{target_url}/v1/chat/completions",
json=encoder_req,
headers=headers,
)
)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Fail fast if any sub-request failed
for idx, r in enumerate(results):
if isinstance(r, Exception):
logger.error(
"[%s] Encoder request #%d raised exception: %s",
req_id,
idx,
r,
exc_info=r,
)
error_detail = str(r)
if hasattr(r, "status"):
error_detail = f"Status: {r.status}, Error: {error_detail}"
elif hasattr(r, "status_code"):
error_detail = f"Status: {r.status_code}, Error: {error_detail}"
raise HTTPException(status_code=502, detail=f"Encoder request failed: {error_detail}")
if isinstance(r, ClientResponse):
if hasattr(r, "status") and r.status != 200:
try:
detail = await r.text()
except Exception:
detail = "<unable to read body>"
logger.error(
"[%s] Encoder request #%d returned status %s: %s",
req_id,
idx,
r.status,
detail,
)
raise HTTPException(
status_code=r.status,
detail=f"Encoder request failed: {detail}",
)
logger.info("[%s] All %d encoder requests completed successfully", req_id, len(mm_items))
async def _encode_single_request(
orig_request: dict,
e_url: str,
req_id: str,
) -> ClientResponse:
"""
Send the original request (with generation capped at 1 token) to a single
encoder server so it runs the encoder path; raise if the request fails.
"""
logger.info("[%s] Processing multimodal items...", req_id)
request_data = copy.deepcopy(orig_request)
headers = {"x-request-id": req_id}
request_data["max_tokens"] = 1
request_data["stream"] = False
request_data.pop("stream_options", None)
if "max_completion_tokens" in request_data:
request_data["max_completion_tokens"] = 1
try:
if encode_session is None:
raise HTTPException(status_code=500, detail="Encode session not initialized")
encode_response = await encode_session.post(f"{e_url}/v1/chat/completions", json=request_data, headers=headers)
encode_response.raise_for_status()
if encode_response.status != 200:
encode_text = await encode_response.text()
raise HTTPException(
status_code=encode_response.status,
detail={"error": "Encoder request failed", "message": encode_text},
)
logger.debug("Encoder processing completed successfully for req_id: %s", req_id)
return encode_response
except Exception as e:
logger.error("Encoder processing failed: %s", str(e))
raise HTTPException(
status_code=500,
detail={"error": "Encoder processing error", "message": str(e)},
) from e
logger.info("[%s] Encoder request completed successfully", req_id)
async def fanout_encoder_primer(
orig_request: dict,
req_id: str,
):
mode = app.state.encoder_dispatch_mode
if mode == EncoderDispatchMode.SINGLE:
e_url = random.choice(app.state.e_urls)
await _encode_single_request(orig_request, e_url, req_id)
elif mode == EncoderDispatchMode.FANOUT:
await _encode_fanout(orig_request, app.state.e_urls, req_id)
else:
raise RuntimeError(f"Unknown encoder dispatch mode: {mode}")
async def maybe_prefill(
req_data: dict,
p_url: str,
req_id: str,
) -> dict:
"""
- Do prefill-only task if p_url exist;
- Return modified request data with kv transfer params (for nixl connector)
- Else, skip and return the original request data for decode
"""
if p_url:
logger.info("[%s] Processing through prefill: %s", req_id, p_url)
prefill_response = await process_prefill_stage(req_data, p_url, req_id)
if isinstance(prefill_response, ClientResponse):
# for nixl connector to facilitate kv transfer...
prefill_response_json = await prefill_response.json()
kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
return req_data
else:
return req_data
async def process_prefill_stage(
req_data: dict,
p_url: str,
req_id: str,
) -> ClientResponse:
"""Process request through Prefill stage and return kv_transfer_params"""
logger.info("[%s] Sending prefill request to: %s", req_id, p_url)
prefill_request = req_data.copy()
prefill_request["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
prefill_request["stream"] = False
prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
if "stream_options" in prefill_request:
del prefill_request["stream_options"]
headers = {"x-request-id": req_id}
try:
if prefill_session is None:
raise HTTPException(status_code=500, detail="Prefill session not initialized")
prefill_response = await prefill_session.post(
f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
)
prefill_response.raise_for_status()
if prefill_response.status != 200:
error_text = await prefill_response.text()
logger.error(
"[%s] Prefill request failed with status %d: %s",
req_id,
prefill_response.status,
error_text,
)
raise HTTPException(
status_code=prefill_response.status,
detail={"error": "Prefill request failed", "message": error_text},
)
logger.info("[%s] Prefill request completed successfully", req_id)
return prefill_response
except Exception as e:
logger.error("Prefill processing failed: %s", str(e))
raise HTTPException(
status_code=500,
detail={"error": "Prefill processing error", "message": str(e)},
) from e
def has_mm_input(request_data: dict):
if "messages" not in request_data:
return False
for message in request_data["messages"]:
if not isinstance(message.get("content"), list):
continue
for content_item in message["content"]:
if content_item.get("type") in ["image_url", "audio_url", "input_audio"]:
return True
return False
###############################################################################
# Middleware for request/response logging
###############################################################################
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""Middleware to log all incoming requests and responses"""
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
# Log incoming request
logger.info(
">>> [%s] %s %s from %s",
req_id,
request.method,
request.url.path,
request.client.host if request.client else "unknown",
)
try:
# Process request
response = await call_next(request)
# Log response
logger.info(
"<<< [%s] %s %s completed with status %d",
req_id,
request.method,
request.url.path,
response.status_code,
)
return response
except Exception as e:
# Log errors
logger.exception(
"!!! [%s] %s %s failed with error: %s",
req_id,
request.method,
request.url.path,
str(e),
)
raise
###############################################################################
# FastAPI lifecycle
###############################################################################
@app.on_event("startup")
async def on_startup() -> None:
global encode_session, prefill_session, decode_session
timeout = aiohttp.ClientTimeout(total=100_000)
connector = aiohttp.TCPConnector(limit=0, force_close=False, keepalive_timeout=0)
encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
if app.state.p_urls:
# only setup if prefill instance(s) exist
prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
@app.on_event("shutdown")
async def on_shutdown() -> None:
global encode_session, prefill_session, decode_session
if encode_session:
await encode_session.close()
if prefill_session:
await prefill_session.close()
if decode_session:
await decode_session.close()
###############################################################################
# Core forwarding
###############################################################################
async def forward_non_stream(req_data: dict, req_id: str, p_url: str, d_url: str) -> dict:
try:
# Step 1: Process through Encoder instance (if has MM input)
async def run_encoder():
await fanout_encoder_primer(req_data, req_id)
if has_mm_input(req_data):
await non_stream_retry_wrap(run_encoder)
# Step 2: Process through Prefill instance
async def run_prefill():
return await maybe_prefill(req_data, p_url, req_id)
req_data = await non_stream_retry_wrap(run_prefill)
async def run_decode_non_stream():
# Step 3: Process through Decode instance
logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
headers = {"x-request-id": req_id}
# Non-streaming response
if decode_session is None:
raise HTTPException(status_code=500, detail="Decode session not initialized")
async with decode_session.post(f"{d_url}/v1/chat/completions", json=req_data, headers=headers) as resp:
resp.raise_for_status()
return await resp.json()
return await non_stream_retry_wrap(run_decode_non_stream)
except HTTPException:
raise
except Exception as e:
logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e
async def stream_retry_wrap(forward_func, max_retries: int = 3, delay: float = 0.001):
last_exc = None
first_chunk_sent = False
for attempt in range(max_retries):
try:
async for chunk in forward_func():
first_chunk_sent = True
yield chunk
return
except Exception as e:
if first_chunk_sent:
raise
if isinstance(e, HTTPException) and e.status_code < 500:
raise
last_exc = e
logger.warning(
"attempt %s / %s failed retrying... ",
attempt + 1,
max_retries,
)
await asyncio.sleep(delay * (attempt + 1))
raise RuntimeError(f"all {max_retries} retries failed.") from last_exc
async def non_stream_retry_wrap(forward_func, max_retries: int = 3, delay: float = 0.001):
last_exc = None
for attempt in range(max_retries):
try:
result = await forward_func()
return result
except Exception as e:
if isinstance(e, HTTPException) and e.status_code < 500:
raise
last_exc = e
logger.warning(
"attempt %s / %s failed retrying... ",
attempt + 1,
max_retries,
)
await asyncio.sleep(delay * (attempt + 1))
raise RuntimeError(f"all {max_retries} retries failed.") from last_exc
async def forward_stream(req_data: dict, req_id: str, p_url: str, d_url: str) -> AsyncIterator[str]:
try:
# Step 1: Process through Encoder instance (if has MM input)
async def run_encoder():
await fanout_encoder_primer(req_data, req_id)
if has_mm_input(req_data):
await non_stream_retry_wrap(run_encoder)
# Step 2: Process through Prefill instance
async def run_prefill():
return await maybe_prefill(req_data, p_url, req_id)
req_data = await non_stream_retry_wrap(run_prefill)
async def run_decode_stream():
# Step 3: Process through Decode instance
logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
headers = {"x-request-id": req_id}
# Streaming response
if decode_session is None:
raise HTTPException(status_code=500, detail="Decode session not initialized")
async with decode_session.post(
f"{d_url}/v1/chat/completions",
json=req_data,
headers=headers,
) as resp:
resp.raise_for_status()
async for chunk in resp.content.iter_chunked(1024):
if chunk:
yield chunk.decode("utf-8", errors="ignore")
logger.info("[%s] Streaming completed", req_id)
async for chunk in stream_retry_wrap(run_decode_stream):
yield chunk
except HTTPException:
logger.exception("[%s] HTTPException in forward_stream", req_id)
raise
except Exception as e:
logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
raise HTTPException(status_code=500, detail=f"Proxy streaming error: {str(e)}") from e
###############################################################################
# Public routes
###############################################################################
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
try:
req_data = await request.json()
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
is_streaming = req_data.get("stream", False)
if is_streaming:
return StreamingResponse(
forward_stream(req_data, req_id, p_url, d_url),
media_type="text/event-stream",
)
result = await forward_non_stream(req_data, req_id, p_url, d_url)
return JSONResponse(content=result)
except HTTPException:
raise
except Exception as e:
logger.exception("Error in chat_completions endpoint: %s", str(e))
raise HTTPException(status_code=500, detail=f"Request processing error: {str(e)}") from e
@app.get("/v1/models")
async def list_models():
if decode_session is None:
raise HTTPException(status_code=500, detail="Decode session not initialized")
async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
resp.raise_for_status()
return await resp.json()
@app.get("/health")
async def health_check():
async def healthy(urls, session):
if not urls:
return "empty"
for u in urls:
try:
if session is None:
return "unhealthy"
async with session.get(f"{u}/health") as resp:
resp.raise_for_status()
except Exception:
return "unhealthy"
return "healthy"
e_status, p_status, d_status = await asyncio.gather(
healthy(app.state.e_urls, encode_session),
healthy(app.state.p_urls, prefill_session),
healthy(app.state.d_urls, decode_session),
)
overall_healthy = all(status != "unhealthy" for status in (e_status, p_status, d_status))
status_code = 200 if overall_healthy else 503
return JSONResponse(
{
"proxy": "healthy",
"encode_cluster": e_status,
"prefill_cluster": p_status,
"decode_cluster": d_status,
},
status_code=status_code,
)
###############################################################################
# Simple profiler fan-out (unchanged except for sessions)
###############################################################################
async def _post_if_available(
session: aiohttp.ClientSession,
url: str,
payload: dict,
headers: dict,
) -> dict | None:
"""
POST `payload` to `url`.
Returns
-------
• The decoded JSON body on success (2xx)
• None if the endpoint does not exist (404)
• Raises for anything else.
"""
try:
if session is None:
return None
resp = await session.post(url, json=payload, headers=headers)
if resp.status == 404: # profiling disabled on that server
logger.warning("Profiling endpoint missing on %s", url)
return None
resp.raise_for_status()
return await resp.json(content_type=None)
except aiohttp.ClientResponseError as exc:
# Pass 404 through the branch above, re-raise everything else
if exc.status == 404:
logger.warning("Profiling endpoint missing on %s", url)
return None
raise
except Exception:
# Network errors etc.: propagate
raise
async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
"""
Fire & forget to both clusters, tolerate 404.
"""
headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"}
encode_task = _post_if_available(encode_session, f"{e_url}/{cmd}_profile", payload, headers)
prefill_task = (
_post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers)
if p_url is not None
else asyncio.sleep(0)
)
decode_task = _post_if_available(decode_session, f"{d_url}/{cmd}_profile", payload, headers)
encode_res, prefill_res, decode_res = await asyncio.gather(encode_task, prefill_task, decode_task)
# If *all* clusters said "I don't have that route", surface an error
if encode_res is prefill_res is decode_res is None:
raise HTTPException(
status_code=503,
detail="Profiling endpoints are disabled on all clusters",
)
return {
"encode": encode_res, # may be None
"prefill": prefill_res, # may be None
"decode": decode_res, # may be None
}
@app.post("/start_profile")
async def start_profile(request: Request):
body = await request.json()
# TODO: handle multi urls properly
e_url = random.choice(app.state.e_urls)
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
return await _profile_cmd("start", body, e_url, p_url, d_url)
@app.post("/stop_profile")
async def stop_profile(request: Request):
body = await request.json()
# TODO: handle multi urls properly
e_url = random.choice(app.state.e_urls)
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
return await _profile_cmd("stop", body, e_url, p_url, d_url)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--encode-servers-urls",
required=True,
help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")',
)
parser.add_argument(
"--prefill-servers-urls",
required=True,
help='Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") to enable E->P->D, '
'set "disable" or "none" to enable E->PD',
)
parser.add_argument(
"--decode-servers-urls",
required=True,
help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")',
)
parser.add_argument(
"--encoder-dispatch-mode",
choices=["single", "fanout"],
default="single",
help="Encoder dispatch mode: single (one request) or fanout (per-MM-item)",
)
args = parser.parse_args()
app.state.e_urls = [u.strip() for u in args.encode_servers_urls.split(",") if u.strip()]
app.state.d_urls = [u.strip() for u in args.decode_servers_urls.split(",") if u.strip()]
# handle prefill instances
if args.prefill_servers_urls.lower() in ("disable", "none", ""):
app.state.p_urls = []
logger.info("Disaggregated prefill phase explicitly disabled by user. Running E + PD...")
else:
app.state.p_urls = [u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()]
logger.info("Disaggregated prefill phase is enabled. Running E + P + D...")
app.state.encoder_dispatch_mode = EncoderDispatchMode(args.encoder_dispatch_mode)
logger.info("Proxy listening on %s:%s", args.host, args.port)
logger.info("Encode servers: %s", app.state.e_urls)
logger.info("Prefill instances %s", app.state.p_urls)
logger.info("Decode servers: %s", app.state.d_urls)
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
loop="uvloop",
access_log=True,
)

View File

@@ -18,6 +18,7 @@
#
import contextlib
import copy
import functools
import gc
import json
@@ -27,11 +28,15 @@ import os
import shlex
import subprocess
import sys
import threading
import time
import traceback
from pathlib import Path
from typing import Any, Optional, Tuple, TypeVar, Union
import numpy as np
import openai
import psutil
import pytest
import requests
import torch
@@ -80,6 +85,10 @@ logger = logging.getLogger(__name__)
_TEST_DIR = os.path.dirname(__file__)
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "long_prompt.txt")]
DISAGG_EPD_PROXY_SCRIPT = Path(
__file__
).parent.parent.parent / "examples" / "disaggregated_encoder" / "disagg_epd_proxy.py"
def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
import torch_npu # type: ignore
@@ -441,6 +450,216 @@ class RemoteOpenAIServer:
**kwargs)
class RemoteEPDServer(RemoteOpenAIServer):
def _start_server(self, model: str, server_cmd: list[str],
env_dict: Optional[dict[str, str]]) -> None:
"""Subclasses override this method to customize server process launch
"""
raise NotImplementedError("RemoteEPDServer should use _start_server_with_prefix instead")
def __init__(self,
vllm_serve_args: Union[list[str], list[list[str]]],
server_host: str = '0.0.0.0',
env_dict: Optional[dict[str, str]] = None,
max_wait_seconds: Optional[float] = 2800) -> None:
self._proc_list = []
self.env_dict: dict[str, str] = {}
if env_dict is not None:
self.env_dict.update(env_dict)
self.env_dict['VLLM_ALLOW_LONG_MAX_MODEL_LEN'] = "1"
self.env_dict['VLLM_USE_V1'] = "1"
self.env_dict['PYTORCH_NPU_ALLOC_CONF'] = "expandable_segments:True"
self.env_dict['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
self.vllm_serve_args_list = []
self.health_url_list = []
self.host = server_host
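# vllm_serve_args may be a single flat arg list (one instance) or a list of
# per-instance arg lists; normalize to a list of lists so each instance gets its
# own "vllm serve" process and NPU card (index i -> ASCEND_RT_VISIBLE_DEVICES=i).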
if isinstance(vllm_serve_args, list):
if not all(isinstance(item, list) for item in vllm_serve_args):
args_copy = copy.deepcopy(vllm_serve_args)
self.vllm_serve_args_list.append([str(arg) for arg in args_copy])
else:
self.vllm_serve_args_list = [
[str(arg) for arg in sublist]
for sublist in copy.deepcopy(vllm_serve_args)
]
else:
raise RuntimeError("vllm_serves_args must be a list")
serve_arg_cmd = ["vllm", "serve"]
for i, vllm_serve_arg in enumerate(self.vllm_serve_args_list):
self.env_dict['ASCEND_RT_VISIBLE_DEVICES'] = str(i)
if isinstance(vllm_serve_arg, list):
if "--port" not in vllm_serve_arg:
raise ValueError("You have manually specified the port ")
else:
port_arg = "--port"
try:
index = vllm_serve_arg.index(port_arg)
except ValueError:
raise ValueError(f"--port not found in args: {vllm_serve_arg}")
port_str = vllm_serve_arg[index + 1]
self.port = int(port_str)
else:
vllm_serve_arg_str = str(vllm_serve_arg)
if "--port" not in vllm_serve_arg_str:
raise ValueError("You have manually specified the port ")
else:
raise ValueError(f"Unexpected type for vllm_serve_arg: {type(vllm_serve_arg)}")
self.health_url_list.append(super().url_for("health"))
vllm_serve_arg = [*serve_arg_cmd, *vllm_serve_arg]
proc = self._start_server_with_prefix(vllm_serve_arg, self.env_dict,
f"[VLLM_{i}] ")
self._proc_list.append(proc)
timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
super()._wait_for_multiple_servers([(self.host, url)
for url in self.health_url_list],
timeout=timeout_value)
def _poll(self) -> Optional[int]:
return None
def _delete_shm(self) -> None:
for i, arg in enumerate(self.vllm_serve_args_list):
if "--ec-transfer-config" in arg:
index = arg.index("--ec-transfer-config")
config_str = arg[index + 1]
config_dict = json.loads(config_str)
ec_connector_extra_config = config_dict.get("ec_connector_extra_config", {})
shm_path = ec_connector_extra_config.get("shared_storage_path")
if shm_path:
args = ["rm", "-r", "-f", str(shm_path)]
print(f"delete shm_path is: {shm_path}")
self._start_server_with_prefix(args, None, "[DELETE] ")
def _read_output(self, pipe, prefix):
try:
with pipe:
for line in iter(pipe.readline, ''):
if line:
print(f"{prefix}: {line}", end='')
except Exception as e:
print(f"error: {e}")
traceback.print_exc()
def _start_server_with_prefix(self, server_cmd: list[str],
env_dict: Optional[dict[str, str]], log_prefix: str):
env = os.environ.copy()
if env_dict is not None:
env.update(env_dict)
proc = subprocess.Popen(server_cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
bufsize=1)
stdout_thread = threading.Thread(target=self._read_output,
args=(proc.stdout, log_prefix),
daemon=True)
stderr_thread = threading.Thread(target=self._read_output,
args=(proc.stderr, log_prefix),
daemon=True)
stdout_thread.start()
stderr_thread.start()
return proc
def _terminate_server(self) -> None:
"""kill process and its children"""
print("vllm instance is stopping")
for proc in self._proc_list:
parent = psutil.Process(proc.pid)
children = parent.children(recursive=True)
for child in children:
try:
child.terminate()
except psutil.NoSuchProcess:
pass
gone, still_alive = psutil.wait_procs(children, timeout=10)
for child in still_alive:
try:
child.kill()
except psutil.NoSuchProcess:
pass
try:
parent.terminate()
parent.wait(timeout=10)
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
try:
parent.kill()
except psutil.NoSuchProcess:
pass
def __enter__(self):
"""Context manager entry point."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit point - clean up all processes."""
self._terminate_server()
class DisaggEpdProxy(RemoteEPDServer):
def __init__(self,
proxy_args: Optional[Union[list[str], str]] = None,
env_dict: Optional[dict[str, str]] = None,
server_host: str = '0.0.0.0',
max_wait_seconds: Optional[float] = 2800) -> None:
if proxy_args is None:
proxy_args_list: list[str] = []
elif isinstance(proxy_args, str):
proxy_args_list = shlex.split(proxy_args)
else:
proxy_args_list = proxy_args
self.proxy_args = proxy_args_list
self.env_dict: dict[str, str] = {}
if env_dict is not None:
self.env_dict.update(env_dict)
self._proc_list = list()
self.host = server_host
print(f"proxy param is: {self.proxy_args}")
proxy_cmd = ["python", str(DISAGG_EPD_PROXY_SCRIPT), *self.proxy_args]
proc = self._start_server_with_prefix(proxy_cmd, self.env_dict, "[PROXY] ")
self._proc_list.append(proc)
if "--port" not in self.proxy_args:
raise ValueError("You have manually specified the port ")
else:
try:
index = self.proxy_args.index("--port")
except ValueError:
raise ValueError("--port not found in proxy args")
port_str = self.proxy_args[index + 1]
self.port = int(port_str)
timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
super()._wait_for_multiple_servers(
[(self.host, super().url_for("health"))], timeout=timeout_value)
def __enter__(self):
"""Context manager entry point."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit point - clean up all processes."""
super()._terminate_server()
class VllmRunner:
def __init__(

View File

@@ -0,0 +1,71 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import pytest
from vllm.utils.network_utils import get_open_port
from tests.e2e.conftest import DisaggEpdProxy, RemoteEPDServer
from tools.send_mm_request import send_image_request
MODELS = [
"Qwen/Qwen2.5-VL-7B-Instruct",
]
SHARED_STORAGE_PATH = "/dev/shm/epd/storage"
TENSOR_PARALLELS = [1]
@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
async def test_models(model: str, tp_size: int) -> None:
encode_port = get_open_port()
pd_port = get_open_port()
vllm_server_args = [
[
"--port",
str(encode_port), "--model", model, "--gpu-memory-utilization",
"0.01", "--tensor-parallel-size",
str(tp_size), "--enforce-eager", "--no-enable-prefix-caching",
"--max-model-len", "10000", "--max-num-batched-tokens", "10000",
"--max-num-seqs", "1", "--ec-transfer-config",
'{"ec_connector_extra_config":{"shared_storage_path":"' +
SHARED_STORAGE_PATH +
'"},"ec_connector":"ECExampleConnector","ec_role": "ec_producer"}'
],
[
"--port",
str(pd_port), "--model", model, "--gpu-memory-utilization", "0.95",
"--tensor-parallel-size",
str(tp_size), "--enforce-eager", "--max-model-len", "10000",
"--max-num-batched-tokens", "10000", "--max-num-seqs", "128",
"--ec-transfer-config",
'{"ec_connector_extra_config":{"shared_storage_path":"' +
SHARED_STORAGE_PATH +
'"},"ec_connector":"ECExampleConnector","ec_role": "ec_consumer"}'
]
]
proxy_port = get_open_port()
proxy_args = [
"--host", "127.0.0.1", "--port",
str(proxy_port), "--encode-servers-urls",
f"http://localhost:{encode_port}", "--decode-servers-urls",
f"http://localhost:{pd_port}", "--prefill-servers-urls", "disable"
]
with RemoteEPDServer(vllm_serve_args=vllm_server_args) as _:
with DisaggEpdProxy(proxy_args=proxy_args) as proxy:
send_image_request(model, proxy)

View File

@@ -0,0 +1,110 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import pytest
from vllm.utils.network_utils import get_open_port
from tests.e2e.conftest import DisaggEpdProxy, RemoteEPDServer
from tools.aisbench import run_aisbench_cases
MODELS = [
"Qwen/Qwen2.5-VL-7B-Instruct",
]
SHARED_STORAGE_PATH = "/dev/shm/epd/storage"
TENSOR_PARALLELS = [1]
warmup_cases = [{
"case_type": "performance",
"dataset_path": "vllm-ascend/textvqa-perf-1080p",
"request_conf": "vllm_api_stream_chat",
"dataset_conf": "textvqa/textvqa_gen_base64",
"num_prompts": 50,
"max_out_len": 20,
"batch_size": 32,
"request_rate": 0,
"baseline": 1,
"threshold": 0.97
}]
aisbench_cases = [{
"case_type": "accuracy",
"dataset_path": "vllm-ascend/textvqa-lite",
"request_conf": "vllm_api_stream_chat",
"dataset_conf": "textvqa/textvqa_gen_base64",
"max_out_len": 2048,
"batch_size": 128,
"baseline": 82.05,
"threshold": 5
}, {
"case_type": "performance",
"dataset_path": "vllm-ascend/textvqa-perf-1080p",
"request_conf": "vllm_api_stream_chat",
"dataset_conf": "textvqa/textvqa_gen_base64",
"num_prompts": 512,
"max_out_len": 256,
"batch_size": 128,
"request_rate": 0,
"baseline": 1,
"threshold": 0.97
}]
@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
async def test_models(model: str, tp_size: int) -> None:
encode_port = get_open_port()
pd_port = get_open_port()
vllm_server_args = [
[
"--port",
str(encode_port), "--model", model, "--gpu-memory-utilization",
"0.01", "--tensor-parallel-size",
str(tp_size), "--enforce-eager", "--no-enable-prefix-caching",
"--max-model-len", "10000", "--max-num-batched-tokens", "10000",
"--max-num-seqs", "1", "--ec-transfer-config",
'{"ec_connector_extra_config":{"shared_storage_path":"' +
SHARED_STORAGE_PATH +
'"},"ec_connector":"ECExampleConnector","ec_role": "ec_producer"}'
],
[
"--port",
str(pd_port), "--model", model, "--gpu-memory-utilization", "0.95",
"--tensor-parallel-size",
str(tp_size), "--enforce-eager", "--max-model-len", "10000",
"--max-num-batched-tokens", "10000", "--max-num-seqs", "128",
"--ec-transfer-config",
'{"ec_connector_extra_config":{"shared_storage_path":"' +
SHARED_STORAGE_PATH +
'"},"ec_connector":"ECExampleConnector","ec_role": "ec_consumer"}'
]
]
proxy_port = get_open_port()
proxy_args = [
"--host", "127.0.0.1", "--port",
str(proxy_port), "--encode-servers-urls",
f"http://localhost:{encode_port}", "--decode-servers-urls",
f"http://localhost:{pd_port}", "--prefill-servers-urls", "disable"
]
with RemoteEPDServer(vllm_serve_args=vllm_server_args) as _:
with DisaggEpdProxy(proxy_args=proxy_args) as _:
# warm up
run_aisbench_cases(model=model,
port=proxy_port,
aisbench_cases=warmup_cases)
# aisbench test
run_aisbench_cases(model, proxy_port, aisbench_cases)

View File

@@ -67,7 +67,7 @@ def main():
print(f" - {pkg}")
sys.exit(1)
else:
print("All Python packages have __init__.py files.")
print("All Python packages have __init__.py files.")
if __name__ == "__main__":