[Test] Add initial multi modal cases of Qwen2.5-VL-7B-Instruct for disaggregated encoder (#5301)

### What this PR does / why we need it?
This PR adds disaggregated encoder (EPD) tests for Qwen2.5-VL-7B-Instruct.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By running the new tests in CI.

- vLLM version: release/v0.12.0

---------

Signed-off-by: wangyu31577 <wangyu31577@hundsun.com>
Signed-off-by: wangyu <53896905+yenuo26@users.noreply.github.com>
Co-authored-by: wangyu31577 <wangyu31577@hundsun.com>
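As context for review: the flow under test is client → proxy → encoder instance (one priming request per multimodal item) → prefill+decode instance. Below is a minimal client sketch against the proxy; the proxy port (10001, taken from the example script in this PR) and the image path are illustrative assumptions, not part of the patch.

```python
# Minimal client sketch for the EPD proxy (assumes the proxy from
# disagg_1e1pd_example.sh is listening on localhost:10001; the image
# path is a placeholder).
import base64

from openai import OpenAI

client = OpenAI(base_url="http://localhost:10001/v1", api_key="EMPTY")

with open("hato.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
            {"type": "text", "text": "What is in this image?"},
        ],
    }],
    max_tokens=64,
)
print(resp.choices[0].message.content)
```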
@@ -126,6 +126,9 @@ jobs:
      - name: qwen2-5-vl-7b
        os: linux-aarch64-a3-4
        tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b.py
      - name: qwen2-5-vl-7b-epd
        os: linux-aarch64-a3-4
        tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b_epd.py
      - name: qwen2-5-vl-32b
        os: linux-aarch64-a3-4
        tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_32b.py
2  .github/workflows/scripts/config.yaml  vendored
@@ -126,6 +126,8 @@ e2e-multicard-2-cards:
    estimated_time: 1050
  - name: tests/e2e/multicard/2-cards/test_single_request_aclgraph.py
    estimated_time: 215
  - name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
    estimated_time: 90

e2e-multicard-4-cards:
  # TODO: recover skipped tests
206  examples/disaggregated_encoder/disagg_1e1pd_example.sh  Normal file

@@ -0,0 +1,206 @@
#!/bin/bash
set -euo pipefail

declare -a PIDS=()

###############################################################################
# Configuration -- override via env before running
###############################################################################
MODEL="${MODEL:-Qwen/Qwen2.5-VL-7B-Instruct}"
LOG_PATH="${LOG_PATH:-./logs}"
mkdir -p "$LOG_PATH"

ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}"
PROXY_PORT="${PROXY_PORT:-10001}"

CARD_E="${CARD_E:-0}"
CARD_PD="${CARD_PD:-1}"

EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}"  # wait_for_server timeout

NUM_PROMPTS="${NUM_PROMPTS:-100}"  # number of prompts to send in benchmark

###############################################################################
# Helpers
###############################################################################
# vLLM source root (hardcoded for the CI container image)
VLLM_ROOT="/vllm-workspace/vllm"

START_TIME=$(date +"%Y%m%d_%H%M%S")
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
PD_LOG=$LOG_PATH/pd_${START_TIME}.log
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log

wait_for_server() {
  local port=$1
  timeout "$TIMEOUT_SECONDS" bash -c "
    until curl -s localhost:$port/v1/chat/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Cleanup function
cleanup() {
  echo "Stopping everything…"
  trap - INT TERM USR1   # prevent re-entrancy

  # Kill all tracked PIDs
  for pid in "${PIDS[@]}"; do
    if kill -0 "$pid" 2>/dev/null; then
      echo "Killing process $pid"
      kill "$pid" 2>/dev/null
    fi
  done

  # Wait a moment for graceful shutdown
  sleep 2

  # Force kill any remaining processes
  for pid in "${PIDS[@]}"; do
    if kill -0 "$pid" 2>/dev/null; then
      echo "Force killing process $pid"
      kill -9 "$pid" 2>/dev/null
    fi
  done

  # Kill the entire process group as backup
  kill -- -$$ 2>/dev/null

  echo "All processes stopped."
  exit 0
}

trap cleanup INT
trap cleanup USR1
trap cleanup TERM

# clear previous cache
echo "remove previous ec cache folder"
rm -rf "$EC_SHARED_STORAGE_PATH"

echo "make ec cache folder"
mkdir -p "$EC_SHARED_STORAGE_PATH"

###############################################################################
# Encoder worker
###############################################################################
ASCEND_RT_VISIBLE_DEVICES="$CARD_E" vllm serve "$MODEL" \
  --gpu-memory-utilization 0.01 \
  --port "$ENCODE_PORT" \
  --enforce-eager \
  --enable-request-id-headers \
  --no-enable-prefix-caching \
  --max-num-batched-tokens 114688 \
  --max-num-seqs 128 \
  --ec-transfer-config '{
    "ec_connector": "ECExampleConnector",
    "ec_role": "ec_producer",
    "ec_connector_extra_config": {
      "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
    }
  }' \
  >"${ENC_LOG}" 2>&1 &

PIDS+=($!)

###############################################################################
# Prefill+Decode worker
###############################################################################
ASCEND_RT_VISIBLE_DEVICES="$CARD_PD" vllm serve "$MODEL" \
  --gpu-memory-utilization 0.9 \
  --port "$PREFILL_DECODE_PORT" \
  --enforce-eager \
  --enable-request-id-headers \
  --max-num-seqs 128 \
  --ec-transfer-config '{
    "ec_connector": "ECExampleConnector",
    "ec_role": "ec_consumer",
    "ec_connector_extra_config": {
      "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
    }
  }' \
  >"${PD_LOG}" 2>&1 &

PIDS+=($!)

# Wait for workers
wait_for_server "$ENCODE_PORT"
wait_for_server "$PREFILL_DECODE_PORT"

###############################################################################
# Proxy
###############################################################################
python ./disagg_epd_proxy.py \
  --host "0.0.0.0" \
  --port "$PROXY_PORT" \
  --encode-servers-urls "http://localhost:$ENCODE_PORT" \
  --prefill-servers-urls "disable" \
  --decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
  >"${PROXY_LOG}" 2>&1 &

PIDS+=($!)

wait_for_server "$PROXY_PORT"
echo "All services are up!"

###############################################################################
# Single request with local image
###############################################################################
echo "Running single request with local image (non-stream)..."

base64_image=$(base64 -w 0 "${VLLM_ROOT}/tests/v1/ec_connector/integration/hato.jpg")

cat > /tmp/request.json << EOF
{
  "model": "${MODEL}",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": [
        {
          "type": "image_url",
          "image_url": {
            "url": "data:image/jpg;base64,${base64_image}"
          }
        },
        {
          "type": "text",
          "text": "What is in this image?"
        }
      ]
    }
  ]
}
EOF

curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d @/tmp/request.json

rm -f /tmp/request.json

###############################################################################
# Benchmark
###############################################################################
echo "Running benchmark (stream)..."
vllm bench serve \
  --model "$MODEL" \
  --backend openai-chat \
  --endpoint /v1/chat/completions \
  --dataset-name random-mm \
  --seed 0 \
  --num-prompts "$NUM_PROMPTS" \
  --port "$PROXY_PORT"

# cleanup
echo "cleanup..."
cleanup
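The `wait_for_server` helper above simply polls the chat-completions route until it answers. For harnesses written in Python, an equivalent poller is a few lines; this is a sketch, not part of the PR, and the port and timeout values are illustrative:

```python
# Minimal readiness poller mirroring wait_for_server in the script above.
import time

import requests


def wait_for_server(port: int, timeout_s: float = 600.0) -> bool:
    """Poll until the server answers on /v1/chat/completions or time runs out."""
    deadline = time.monotonic() + timeout_s
    url = f"http://localhost:{port}/v1/chat/completions"
    while time.monotonic() < deadline:
        try:
            # Any HTTP answer (even 405 for GET) means the server is up.
            requests.get(url, timeout=2)
            return True
        except requests.RequestException:
            time.sleep(1)
    return False


if __name__ == "__main__":
    assert wait_for_server(10001), "proxy did not come up in time"
```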
749  examples/disaggregated_encoder/disagg_epd_proxy.py  Normal file

@@ -0,0 +1,749 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
disagg_epd_proxy.py

Proxy that routes OpenAI-compatible "/v1/chat/completions" requests to two
clusters:
  • encode (multimodal feature extraction)
  • decode (language-model inference)

For MM input we:
  1. Extract *every* image/audio item.
  2. Fire N concurrent requests to the encoder cluster
     (one request per item, with **all text removed**).
  3. Wait for all of them to succeed.
  4. Forward the *original* request to a decode server.
"""

from __future__ import annotations

import argparse
import asyncio
import copy
import logging
import os
import random
import uuid
from collections.abc import AsyncIterator
from enum import Enum

import aiohttp
import uvicorn
from aiohttp import ClientResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse

###############################################################################
# FastAPI app & global state
###############################################################################

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s")
logger = logging.getLogger("proxy")

app = FastAPI()
encode_session: aiohttp.ClientSession | None = None
prefill_session: aiohttp.ClientSession | None = None
decode_session: aiohttp.ClientSession | None = None

###############################################################################
# Utils
###############################################################################


MM_TYPES = {"image_url", "audio_url", "input_audio"}


class EncoderDispatchMode(str, Enum):
    SINGLE = "single"
    FANOUT = "fanout"


def extract_mm_items(request_data: dict) -> list[dict]:
    """
    Return *all* image/audio items that appear anywhere in `messages`.

    Each returned dict looks like:
        { "type": "image_url", "image_url": {...} }
    """
    items: list[dict] = []
    for msg in request_data.get("messages", []):
        content = msg.get("content")
        if not isinstance(content, list):
            continue

        for item in content:
            if item.get("type") in MM_TYPES:
                items.append(item)
    return items


async def _encode_fanout(
    orig_request: dict,
    e_urls: list[str],
    req_id: str,
):
    logger.info("[%s] Processing multimodal items...", req_id)

    mm_items = extract_mm_items(orig_request)
    if not mm_items:
        logger.info("[%s] No multimodal items, skipping encoder", req_id)
        return  # nothing to do

    logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))

    tasks = []

    # Round-robin over encode servers to distribute load a bit
    url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))

    for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
        # Derive a *child* request id: <parent>:<index>:<random-short>
        child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
        headers = {"x-request-id": child_req_id}

        encoder_req = {
            # You *may* need to keep additional fields
            "model": orig_request.get("model"),
            "messages": [
                {"role": "user", "content": [item]},
            ],
            # Only need 1 token so the server actually runs the encoder path
            "max_tokens": 1,
            "stream": False,
        }
        if encode_session is None:
            raise HTTPException(status_code=500, detail="Encode session not initialized")
        tasks.append(
            encode_session.post(
                f"{target_url}/v1/chat/completions",
                json=encoder_req,
                headers=headers,
            )
        )

    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Fail fast if any sub-request failed
    for idx, r in enumerate(results):
        if isinstance(r, Exception):
            logger.error(
                "[%s] Encoder request #%d raised exception: %s",
                req_id,
                idx,
                r,
                exc_info=r,
            )
            error_detail = str(r)
            if hasattr(r, "status"):
                error_detail = f"Status: {r.status}, Error: {error_detail}"
            elif hasattr(r, "status_code"):
                error_detail = f"Status: {r.status_code}, Error: {error_detail}"
            raise HTTPException(status_code=502, detail=f"Encoder request failed: {error_detail}")
        if isinstance(r, ClientResponse):
            if hasattr(r, "status") and r.status != 200:
                try:
                    detail = await r.text()
                except Exception:
                    detail = "<unable to read body>"
                logger.error(
                    "[%s] Encoder request #%d returned status %s: %s",
                    req_id,
                    idx,
                    r.status,
                    detail,
                )
                raise HTTPException(
                    status_code=r.status,
                    detail=f"Encoder request failed: {detail}",
                )

    logger.info("[%s] All %d encoder requests completed successfully", req_id, len(mm_items))


async def _encode_single_request(
    orig_request: dict,
    e_url: str,
    req_id: str,
) -> ClientResponse:
    """
    Send the original request (capped to one output token, non-streaming)
    to a single encode server, and raise if it fails.
    """
    logger.info("[%s] Processing multimodal items...", req_id)

    request_data = copy.deepcopy(orig_request)
    headers = {"x-request-id": req_id}
    request_data["max_tokens"] = 1
    request_data["stream"] = False
    request_data.pop("stream_options", None)
    if "max_completion_tokens" in request_data:
        request_data["max_completion_tokens"] = 1

    try:
        if encode_session is None:
            raise HTTPException(status_code=500, detail="Encode session not initialized")

        encode_response = await encode_session.post(f"{e_url}/v1/chat/completions", json=request_data, headers=headers)
        encode_response.raise_for_status()

        if encode_response.status != 200:
            encode_text = await encode_response.text()
            raise HTTPException(
                status_code=encode_response.status,
                detail={"error": "Encoder request failed", "message": encode_text},
            )
        logger.debug("Encoder processing completed successfully for req_id: %s", req_id)

        return encode_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Encoder processing failed: %s", str(e))
        raise HTTPException(
            status_code=500,
            detail={"error": "Encoder processing error", "message": str(e)},
        ) from e


async def fanout_encoder_primer(
    orig_request: dict,
    req_id: str,
):
    mode = app.state.encoder_dispatch_mode

    if mode == EncoderDispatchMode.SINGLE:
        e_url = random.choice(app.state.e_urls)
        await _encode_single_request(orig_request, e_url, req_id)

    elif mode == EncoderDispatchMode.FANOUT:
        await _encode_fanout(orig_request, app.state.e_urls, req_id)

    else:
        raise RuntimeError(f"Unknown encoder dispatch mode: {mode}")


async def maybe_prefill(
    req_data: dict,
    p_url: str,
    req_id: str,
) -> dict:
    """
    - Run the prefill-only stage if p_url exists;
    - Return modified request data with kv transfer params (for nixl connector)
    - Else, skip and return the original request data for decode
    """
    if p_url:
        logger.info("[%s] Processing through prefill: %s", req_id, p_url)

        prefill_response = await process_prefill_stage(req_data, p_url, req_id)
        if isinstance(prefill_response, ClientResponse):
            # for nixl connector to facilitate kv transfer...
            prefill_response_json = await prefill_response.json()
            kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
            if kv_transfer_params:
                req_data["kv_transfer_params"] = kv_transfer_params
        return req_data
    else:
        return req_data


async def process_prefill_stage(
    req_data: dict,
    p_url: str,
    req_id: str,
) -> ClientResponse:
    """Process request through Prefill stage and return kv_transfer_params"""
    logger.info("[%s] Sending prefill request to: %s", req_id, p_url)

    prefill_request = req_data.copy()
    prefill_request["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }
    prefill_request["stream"] = False
    prefill_request["max_tokens"] = 1
    if "max_completion_tokens" in prefill_request:
        prefill_request["max_completion_tokens"] = 1
    if "stream_options" in prefill_request:
        del prefill_request["stream_options"]

    headers = {"x-request-id": req_id}
    try:
        if prefill_session is None:
            raise HTTPException(status_code=500, detail="Prefill session not initialized")

        prefill_response = await prefill_session.post(
            f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
        )
        prefill_response.raise_for_status()

        if prefill_response.status != 200:
            error_text = await prefill_response.text()
            logger.error(
                "[%s] Prefill request failed with status %d: %s",
                req_id,
                prefill_response.status,
                error_text,
            )
            raise HTTPException(
                status_code=prefill_response.status,
                detail={"error": "Prefill request failed", "message": error_text},
            )
        logger.info("[%s] Prefill request completed successfully", req_id)

        return prefill_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Prefill processing failed: %s", str(e))
        raise HTTPException(
            status_code=500,
            detail={"error": "Prefill processing error", "message": str(e)},
        ) from e


def has_mm_input(request_data: dict):
    if "messages" not in request_data:
        return False
    for message in request_data["messages"]:
        if not isinstance(message.get("content"), list):
            continue
        for content_item in message["content"]:
            if content_item.get("type") in ["image_url", "audio_url", "input_audio"]:
                return True
    return False


###############################################################################
# Middleware for request/response logging
###############################################################################


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Middleware to log all incoming requests and responses"""
    req_id = request.headers.get("x-request-id", str(uuid.uuid4()))

    # Log incoming request
    logger.info(
        ">>> [%s] %s %s from %s",
        req_id,
        request.method,
        request.url.path,
        request.client.host if request.client else "unknown",
    )

    try:
        # Process request
        response = await call_next(request)

        # Log response
        logger.info(
            "<<< [%s] %s %s completed with status %d",
            req_id,
            request.method,
            request.url.path,
            response.status_code,
        )

        return response
    except Exception as e:
        # Log errors
        logger.exception(
            "!!! [%s] %s %s failed with error: %s",
            req_id,
            request.method,
            request.url.path,
            str(e),
        )
        raise


###############################################################################
# FastAPI lifecycle
###############################################################################


@app.on_event("startup")
async def on_startup() -> None:
    global encode_session, prefill_session, decode_session
    timeout = aiohttp.ClientTimeout(total=100_000)
    connector = aiohttp.TCPConnector(limit=0, force_close=False, keepalive_timeout=0)
    encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
    if app.state.p_urls:
        # only setup if prefill instance(s) exist
        prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
    decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)


@app.on_event("shutdown")
async def on_shutdown() -> None:
    global encode_session, prefill_session, decode_session
    if encode_session:
        await encode_session.close()
    if prefill_session:
        await prefill_session.close()
    if decode_session:
        await decode_session.close()


###############################################################################
# Core forwarding
###############################################################################


async def forward_non_stream(req_data: dict, req_id: str, p_url: str, d_url: str) -> dict:
    try:
        # Step 1: Process through Encoder instance (if has MM input)
        async def run_encoder():
            await fanout_encoder_primer(req_data, req_id)

        if has_mm_input(req_data):
            await non_stream_retry_wrap(run_encoder)

        # Step 2: Process through Prefill instance
        async def run_prefill():
            return await maybe_prefill(req_data, p_url, req_id)

        req_data = await non_stream_retry_wrap(run_prefill)

        async def run_decode_non_stream():
            # Step 3: Process through Decode instance
            logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
            headers = {"x-request-id": req_id}

            # Non-streaming response
            if decode_session is None:
                raise HTTPException(status_code=500, detail="Decode session not initialized")

            async with decode_session.post(f"{d_url}/v1/chat/completions", json=req_data, headers=headers) as resp:
                resp.raise_for_status()
                return await resp.json()

        return await non_stream_retry_wrap(run_decode_non_stream)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
        raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e


async def stream_retry_wrap(forward_func, max_retries: int = 3, delay: float = 0.001):
    last_exc = None
    first_chunk_sent = False
    for attempt in range(max_retries):
        try:
            async for chunk in forward_func():
                first_chunk_sent = True
                yield chunk
            return
        except Exception as e:
            if first_chunk_sent:
                raise
            if isinstance(e, HTTPException) and e.status_code < 500:
                raise
            last_exc = e
            logger.warning(
                "attempt %s / %s failed, retrying...",
                attempt + 1,
                max_retries,
            )
            await asyncio.sleep(delay * (attempt + 1))

    raise RuntimeError(f"all {max_retries} retries failed.") from last_exc


async def non_stream_retry_wrap(forward_func, max_retries: int = 3, delay: float = 0.001):
    last_exc = None
    for attempt in range(max_retries):
        try:
            result = await forward_func()
            return result
        except Exception as e:
            if isinstance(e, HTTPException) and e.status_code < 500:
                raise
            last_exc = e
            logger.warning(
                "attempt %s / %s failed, retrying...",
                attempt + 1,
                max_retries,
            )
            await asyncio.sleep(delay * (attempt + 1))
    raise RuntimeError(f"all {max_retries} retries failed.") from last_exc


async def forward_stream(req_data: dict, req_id: str, p_url: str, d_url: str) -> AsyncIterator[str]:
    try:
        # Step 1: Process through Encoder instance (if has MM input)
        async def run_encoder():
            await fanout_encoder_primer(req_data, req_id)

        if has_mm_input(req_data):
            await non_stream_retry_wrap(run_encoder)

        # Step 2: Process through Prefill instance
        async def run_prefill():
            return await maybe_prefill(req_data, p_url, req_id)

        req_data = await non_stream_retry_wrap(run_prefill)

        async def run_decode_stream():
            # Step 3: Process through Decode instance
            logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
            headers = {"x-request-id": req_id}

            # Streaming response
            if decode_session is None:
                raise HTTPException(status_code=500, detail="Decode session not initialized")

            async with decode_session.post(
                f"{d_url}/v1/chat/completions",
                json=req_data,
                headers=headers,
            ) as resp:
                resp.raise_for_status()
                async for chunk in resp.content.iter_chunked(1024):
                    if chunk:
                        yield chunk.decode("utf-8", errors="ignore")

            logger.info("[%s] Streaming completed", req_id)

        async for chunk in stream_retry_wrap(run_decode_stream):
            yield chunk

    except HTTPException:
        logger.exception("[%s] HTTPException in forward_stream", req_id)
        raise
    except Exception as e:
        logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
        raise HTTPException(status_code=500, detail=f"Proxy streaming error: {str(e)}") from e


###############################################################################
# Public routes
###############################################################################


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        req_data = await request.json()
        req_id = request.headers.get("x-request-id", str(uuid.uuid4()))

        p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
        d_url = random.choice(app.state.d_urls)

        is_streaming = req_data.get("stream", False)

        if is_streaming:
            return StreamingResponse(
                forward_stream(req_data, req_id, p_url, d_url),
                media_type="text/event-stream",
            )
        result = await forward_non_stream(req_data, req_id, p_url, d_url)
        return JSONResponse(content=result)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in chat_completions endpoint: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Request processing error: {str(e)}") from e


@app.get("/v1/models")
async def list_models():
    if decode_session is None:
        raise HTTPException(status_code=500, detail="Decode session not initialized")
    async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
        resp.raise_for_status()
        return await resp.json()


@app.get("/health")
async def health_check():
    async def healthy(urls, session):
        if not urls:
            return "empty"
        for u in urls:
            try:
                if session is None:
                    return "unhealthy"
                async with session.get(f"{u}/health") as resp:
                    resp.raise_for_status()
            except Exception:
                return "unhealthy"
        return "healthy"

    e_status, p_status, d_status = await asyncio.gather(
        healthy(app.state.e_urls, encode_session),
        healthy(app.state.p_urls, prefill_session),
        healthy(app.state.d_urls, decode_session),
    )

    overall_healthy = all(status != "unhealthy" for status in (e_status, p_status, d_status))

    status_code = 200 if overall_healthy else 503

    return JSONResponse(
        {
            "proxy": "healthy",
            "encode_cluster": e_status,
            "prefill_cluster": p_status,
            "decode_cluster": d_status,
        },
        status_code=status_code,
    )


###############################################################################
# Simple profiler fan-out (unchanged except for sessions)
###############################################################################


async def _post_if_available(
    session: aiohttp.ClientSession,
    url: str,
    payload: dict,
    headers: dict,
) -> dict | None:
    """
    POST `payload` to `url`.

    Returns
    -------
    • The decoded JSON body on success (2xx)
    • None if the endpoint does not exist (404)
    • Raises for anything else.
    """
    try:
        if session is None:
            return None
        resp = await session.post(url, json=payload, headers=headers)
        if resp.status == 404:  # profiling disabled on that server
            logger.warning("Profiling endpoint missing on %s", url)
            return None
        resp.raise_for_status()
        return await resp.json(content_type=None)
    except aiohttp.ClientResponseError as exc:
        # Pass 404 through the branch above, re-raise everything else
        if exc.status == 404:
            logger.warning("Profiling endpoint missing on %s", url)
            return None
        raise
    except Exception:
        # Network errors etc.: propagate
        raise


async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
    """
    Fire & forget to both clusters, tolerate 404.
    """
    headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"}

    encode_task = _post_if_available(encode_session, f"{e_url}/{cmd}_profile", payload, headers)
    prefill_task = (
        _post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers)
        if p_url is not None
        else asyncio.sleep(0)
    )
    decode_task = _post_if_available(decode_session, f"{d_url}/{cmd}_profile", payload, headers)

    encode_res, prefill_res, decode_res = await asyncio.gather(encode_task, prefill_task, decode_task)

    # If *all* clusters said "I don't have that route", surface an error
    if encode_res is prefill_res is decode_res is None:
        raise HTTPException(
            status_code=503,
            detail="Profiling endpoints are disabled on all clusters",
        )

    return {
        "encode": encode_res,  # may be None
        "prefill": prefill_res,  # may be None
        "decode": decode_res,  # may be None
    }


@app.post("/start_profile")
async def start_profile(request: Request):
    body = await request.json()
    # TODO: handle multi urls properly
    e_url = random.choice(app.state.e_urls)
    p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
    d_url = random.choice(app.state.d_urls)
    return await _profile_cmd("start", body, e_url, p_url, d_url)


@app.post("/stop_profile")
async def stop_profile(request: Request):
    body = await request.json()
    # TODO: handle multi urls properly
    e_url = random.choice(app.state.e_urls)
    p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
    d_url = random.choice(app.state.d_urls)
    return await _profile_cmd("stop", body, e_url, p_url, d_url)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--encode-servers-urls",
        required=True,
        help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")',
    )
    parser.add_argument(
        "--prefill-servers-urls",
        required=True,
        help='Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") to enable E->P->D, '
        'set "disable" or "none" to enable E->PD',
    )
    parser.add_argument(
        "--decode-servers-urls",
        required=True,
        help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")',
    )

    parser.add_argument(
        "--encoder-dispatch-mode",
        choices=["single", "fanout"],
        default="single",
        help="Encoder dispatch mode: single (one request) or fanout (per-MM-item)",
    )

    args = parser.parse_args()
    app.state.e_urls = [u.strip() for u in args.encode_servers_urls.split(",") if u.strip()]
    app.state.d_urls = [u.strip() for u in args.decode_servers_urls.split(",") if u.strip()]
    # handle prefill instances
    if args.prefill_servers_urls.lower() in ("disable", "none", ""):
        app.state.p_urls = []
        logger.info("Disaggregated prefill phase explicitly disabled by user. Running E + PD...")
    else:
        app.state.p_urls = [u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()]
        logger.info("Disaggregated prefill phase is enabled. Running E + P + D...")

    app.state.encoder_dispatch_mode = EncoderDispatchMode(args.encoder_dispatch_mode)

    logger.info("Proxy listening on %s:%s", args.host, args.port)
    logger.info("Encode servers: %s", app.state.e_urls)
    logger.info("Prefill instances: %s", app.state.p_urls)
    logger.info("Decode servers: %s", app.state.d_urls)

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
        loop="uvloop",
        access_log=True,
    )
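To make the proxy's two dispatch modes concrete, here is a small self-contained sketch that inlines the `extract_mm_items` logic from the file above and prints the text-free child requests the `fanout` path would send to the encode cluster. The sample payload is illustrative only.

```python
# Illustrative only: inlines the extract_mm_items logic from the proxy above
# to show what the fanout path sends per multimodal item.
MM_TYPES = {"image_url", "audio_url", "input_audio"}

request = {
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
            {"type": "text", "text": "Compare these two images."},
        ],
    }],
}

items = [
    item
    for msg in request["messages"]
    if isinstance(msg.get("content"), list)
    for item in msg["content"]
    if item.get("type") in MM_TYPES
]

# Fanout mode: one text-free child request per item, capped at one output
# token so the encoder path runs without any real decoding.
for item in items:
    child = {
        "model": request["model"],
        "messages": [{"role": "user", "content": [item]}],
        "max_tokens": 1,
        "stream": False,
    }
    print(child)
```

In `single` mode, by contrast, the whole original request (similarly capped to one token) is sent to one randomly chosen encode server.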
@@ -18,6 +18,7 @@
#

import contextlib
import copy
import functools
import gc
import json
@@ -27,11 +28,15 @@ import os
import shlex
import subprocess
import sys
import threading
import time
import traceback
from pathlib import Path
from typing import Any, Optional, Tuple, TypeVar, Union

import numpy as np
import openai
import psutil
import pytest
import requests
import torch
@@ -80,6 +85,10 @@ logger = logging.getLogger(__name__)
_TEST_DIR = os.path.dirname(__file__)
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "long_prompt.txt")]

DISAGG_EPD_PROXY_SCRIPT = Path(
    __file__
).parent.parent.parent / "examples" / "disaggregated_encoder" / "disagg_epd_proxy.py"


def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
    import torch_npu  # type: ignore
@@ -441,6 +450,216 @@ class RemoteOpenAIServer:
                                 **kwargs)


class RemoteEPDServer(RemoteOpenAIServer):
    def _start_server(self, model: str, server_cmd: list[str],
                      env_dict: Optional[dict[str, str]]) -> None:
        """Subclasses override this method to customize server process launch
        """
        raise NotImplementedError("RemoteEPDServer should use _start_server_with_prefix instead")

    def __init__(self,
                 vllm_serve_args: Union[list[str], list[list[str]]],
                 server_host: str = '0.0.0.0',
                 env_dict: Optional[dict[str, str]] = None,
                 max_wait_seconds: Optional[float] = 2800) -> None:

        self._proc_list = []

        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)

        self.env_dict['VLLM_ALLOW_LONG_MAX_MODEL_LEN'] = "1"
        self.env_dict['VLLM_USE_V1'] = "1"
        self.env_dict['PYTORCH_NPU_ALLOC_CONF'] = "expandable_segments:True"
        self.env_dict['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

        self.vllm_serve_args_list = []
        self.health_url_list = []
        self.host = server_host

        if isinstance(vllm_serve_args, list):
            if not all(isinstance(item, list) for item in vllm_serve_args):
                args_copy = copy.deepcopy(vllm_serve_args)
                self.vllm_serve_args_list.append([str(arg) for arg in args_copy])
            else:
                self.vllm_serve_args_list = [
                    [str(arg) for arg in sublist]
                    for sublist in copy.deepcopy(vllm_serve_args)
                ]
        else:
            raise RuntimeError("vllm_serve_args must be a list")

        serve_arg_cmd = ["vllm", "serve"]

        for i, vllm_serve_arg in enumerate(self.vllm_serve_args_list):
            self.env_dict['ASCEND_RT_VISIBLE_DEVICES'] = str(i)
            if isinstance(vllm_serve_arg, list):
                if "--port" not in vllm_serve_arg:
                    raise ValueError("You must specify --port in the serve args")
                else:
                    port_arg = "--port"
                    try:
                        index = vllm_serve_arg.index(port_arg)
                    except ValueError:
                        raise ValueError(f"--port not found in args: {vllm_serve_arg}")
                    port_str = vllm_serve_arg[index + 1]
                    self.port = int(port_str)
            else:
                vllm_serve_arg_str = str(vllm_serve_arg)
                if "--port" not in vllm_serve_arg_str:
                    raise ValueError("You must specify --port in the serve args")
                else:
                    raise ValueError(f"Unexpected type for vllm_serve_arg: {type(vllm_serve_arg)}")

            self.health_url_list.append(super().url_for("health"))
            vllm_serve_arg = [*serve_arg_cmd, *vllm_serve_arg]
            proc = self._start_server_with_prefix(vllm_serve_arg, self.env_dict,
                                                  f"[VLLM_{i}] ")
            self._proc_list.append(proc)

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        super()._wait_for_multiple_servers([(self.host, url)
                                            for url in self.health_url_list],
                                           timeout=timeout_value)

    def _poll(self) -> Optional[int]:
        # multiple child processes are tracked in _proc_list, so the
        # base-class single-process polling is a no-op here
        return None

    def _delete_shm(self) -> None:
        for arg in self.vllm_serve_args_list:
            if "--ec-transfer-config" in arg:
                index = arg.index("--ec-transfer-config")
                config_str = arg[index + 1]
                config_dict = json.loads(config_str)
                ec_connector_extra_config = config_dict.get("ec_connector_extra_config", {})
                shm_path = ec_connector_extra_config.get("shared_storage_path")
                if shm_path:
                    args = ["rm", "-r", "-f", str(shm_path)]
                    print(f"deleting shm_path: {shm_path}")
                    self._start_server_with_prefix(args, None, "[DELETE] ")

    def _read_output(self, pipe, prefix):
        try:
            with pipe:
                for line in iter(pipe.readline, ''):
                    if line:
                        print(f"{prefix}: {line}", end='')

        except Exception as e:
            print(f"error: {e}")
            traceback.print_exc()

    def _start_server_with_prefix(self, server_cmd: list[str],
                                  env_dict: Optional[dict[str, str]], log_prefix: str):
        env = os.environ.copy()
        if env_dict is not None:
            env.update(env_dict)
        proc = subprocess.Popen(server_cmd,
                                env=env,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                universal_newlines=True,
                                bufsize=1)
        stdout_thread = threading.Thread(target=self._read_output,
                                         args=(proc.stdout, log_prefix),
                                         daemon=True)
        stderr_thread = threading.Thread(target=self._read_output,
                                         args=(proc.stderr, log_prefix),
                                         daemon=True)

        stdout_thread.start()
        stderr_thread.start()
        return proc

    def _terminate_server(self) -> None:
        """kill each process and its children"""
        print("vllm instance is stopping")
        for proc in self._proc_list:
            parent = psutil.Process(proc.pid)
            children = parent.children(recursive=True)
            for child in children:
                try:
                    child.terminate()
                except psutil.NoSuchProcess:
                    pass

            gone, still_alive = psutil.wait_procs(children, timeout=10)

            for child in still_alive:
                try:
                    child.kill()
                except psutil.NoSuchProcess:
                    pass

            try:
                parent.terminate()
                parent.wait(timeout=10)
            except (psutil.NoSuchProcess, psutil.TimeoutExpired):
                try:
                    parent.kill()
                except psutil.NoSuchProcess:
                    pass

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        self._terminate_server()


class DisaggEpdProxy(RemoteEPDServer):

    def __init__(self,
                 proxy_args: Optional[Union[list[str], str]] = None,
                 env_dict: Optional[dict[str, str]] = None,
                 server_host: str = '0.0.0.0',
                 max_wait_seconds: Optional[float] = 2800) -> None:

        if proxy_args is None:
            proxy_args_list: list[str] = []
        elif isinstance(proxy_args, str):
            proxy_args_list = shlex.split(proxy_args)
        else:
            proxy_args_list = proxy_args

        self.proxy_args = proxy_args_list
        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)
        self._proc_list = list()
        self.host = server_host

        print(f"proxy args: {self.proxy_args}")
        proxy_cmd = ["python", str(DISAGG_EPD_PROXY_SCRIPT), *self.proxy_args]
        proc = self._start_server_with_prefix(proxy_cmd, self.env_dict, "[PROXY] ")
        self._proc_list.append(proc)

        if "--port" not in self.proxy_args:
            raise ValueError("You must specify --port in the proxy args")
        else:
            try:
                index = self.proxy_args.index("--port")
            except ValueError:
                raise ValueError("--port not found in proxy args")
            port_str = self.proxy_args[index + 1]
            self.port = int(port_str)

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        super()._wait_for_multiple_servers(
            [(self.host, super().url_for("health"))], timeout=timeout_value)

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        super()._terminate_server()


class VllmRunner:

    def __init__(
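The `--ec-transfer-config` value that `_delete_shm` above parses is plain JSON; a minimal sketch of the same extraction, using an illustrative config value:

```python
# Minimal sketch of the --ec-transfer-config JSON contract parsed by
# _delete_shm above; the config value here is illustrative.
import json

ec_transfer_config = json.loads(
    '{"ec_connector": "ECExampleConnector",'
    ' "ec_role": "ec_producer",'
    ' "ec_connector_extra_config": {"shared_storage_path": "/dev/shm/epd/storage"}}'
)

shm_path = ec_transfer_config.get("ec_connector_extra_config", {}).get(
    "shared_storage_path")
assert shm_path == "/dev/shm/epd/storage"
```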
71  tests/e2e/multicard/2-cards/test_disaggregated_encoder.py  Normal file

@@ -0,0 +1,71 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import pytest
from vllm.utils.network_utils import get_open_port

from tests.e2e.conftest import DisaggEpdProxy, RemoteEPDServer
from tools.send_mm_request import send_image_request

MODELS = [
    "Qwen/Qwen2.5-VL-7B-Instruct",
]
SHARED_STORAGE_PATH = "/dev/shm/epd/storage"
TENSOR_PARALLELS = [1]


@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
async def test_models(model: str, tp_size: int) -> None:
    encode_port = get_open_port()
    pd_port = get_open_port()
    vllm_server_args = [
        [
            "--port",
            str(encode_port), "--model", model, "--gpu-memory-utilization",
            "0.01", "--tensor-parallel-size",
            str(tp_size), "--enforce-eager", "--no-enable-prefix-caching",
            "--max-model-len", "10000", "--max-num-batched-tokens", "10000",
            "--max-num-seqs", "1", "--ec-transfer-config",
            '{"ec_connector_extra_config":{"shared_storage_path":"' +
            SHARED_STORAGE_PATH +
            '"},"ec_connector":"ECExampleConnector","ec_role": "ec_producer"}'
        ],
        [
            "--port",
            str(pd_port), "--model", model, "--gpu-memory-utilization", "0.95",
            "--tensor-parallel-size",
            str(tp_size), "--enforce-eager", "--max-model-len", "10000",
            "--max-num-batched-tokens", "10000", "--max-num-seqs", "128",
            "--ec-transfer-config",
            '{"ec_connector_extra_config":{"shared_storage_path":"' +
            SHARED_STORAGE_PATH +
            '"},"ec_connector":"ECExampleConnector","ec_role": "ec_consumer"}'
        ]
    ]
    proxy_port = get_open_port()
    proxy_args = [
        "--host", "127.0.0.1", "--port",
        str(proxy_port), "--encode-servers-urls",
        f"http://localhost:{encode_port}", "--decode-servers-urls",
        f"http://localhost:{pd_port}", "--prefill-servers-urls", "disable"
    ]

    with RemoteEPDServer(vllm_serve_args=vllm_server_args) as _:
        with DisaggEpdProxy(proxy_args=proxy_args) as proxy:
            send_image_request(model, proxy)
110  tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b_epd.py  Normal file

@@ -0,0 +1,110 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import pytest
from vllm.utils.network_utils import get_open_port

from tests.e2e.conftest import DisaggEpdProxy, RemoteEPDServer
from tools.aisbench import run_aisbench_cases

MODELS = [
    "Qwen/Qwen2.5-VL-7B-Instruct",
]
SHARED_STORAGE_PATH = "/dev/shm/epd/storage"
TENSOR_PARALLELS = [1]

warmup_cases = [{
    "case_type": "performance",
    "dataset_path": "vllm-ascend/textvqa-perf-1080p",
    "request_conf": "vllm_api_stream_chat",
    "dataset_conf": "textvqa/textvqa_gen_base64",
    "num_prompts": 50,
    "max_out_len": 20,
    "batch_size": 32,
    "request_rate": 0,
    "baseline": 1,
    "threshold": 0.97
}]
aisbench_cases = [{
    "case_type": "accuracy",
    "dataset_path": "vllm-ascend/textvqa-lite",
    "request_conf": "vllm_api_stream_chat",
    "dataset_conf": "textvqa/textvqa_gen_base64",
    "max_out_len": 2048,
    "batch_size": 128,
    "baseline": 82.05,
    "threshold": 5
}, {
    "case_type": "performance",
    "dataset_path": "vllm-ascend/textvqa-perf-1080p",
    "request_conf": "vllm_api_stream_chat",
    "dataset_conf": "textvqa/textvqa_gen_base64",
    "num_prompts": 512,
    "max_out_len": 256,
    "batch_size": 128,
    "request_rate": 0,
    "baseline": 1,
    "threshold": 0.97
}]


@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
async def test_models(model: str, tp_size: int) -> None:
    encode_port = get_open_port()
    pd_port = get_open_port()
    vllm_server_args = [
        [
            "--port",
            str(encode_port), "--model", model, "--gpu-memory-utilization",
            "0.01", "--tensor-parallel-size",
            str(tp_size), "--enforce-eager", "--no-enable-prefix-caching",
            "--max-model-len", "10000", "--max-num-batched-tokens", "10000",
            "--max-num-seqs", "1", "--ec-transfer-config",
            '{"ec_connector_extra_config":{"shared_storage_path":"' +
            SHARED_STORAGE_PATH +
            '"},"ec_connector":"ECExampleConnector","ec_role": "ec_producer"}'
        ],
        [
            "--port",
            str(pd_port), "--model", model, "--gpu-memory-utilization", "0.95",
            "--tensor-parallel-size",
            str(tp_size), "--enforce-eager", "--max-model-len", "10000",
            "--max-num-batched-tokens", "10000", "--max-num-seqs", "128",
            "--ec-transfer-config",
            '{"ec_connector_extra_config":{"shared_storage_path":"' +
            SHARED_STORAGE_PATH +
            '"},"ec_connector":"ECExampleConnector","ec_role": "ec_consumer"}'
        ]
    ]
    proxy_port = get_open_port()
    proxy_args = [
        "--host", "127.0.0.1", "--port",
        str(proxy_port), "--encode-servers-urls",
        f"http://localhost:{encode_port}", "--decode-servers-urls",
        f"http://localhost:{pd_port}", "--prefill-servers-urls", "disable"
    ]

    with RemoteEPDServer(vllm_serve_args=vllm_server_args) as _:
        with DisaggEpdProxy(proxy_args=proxy_args) as _:
            # warm up
            run_aisbench_cases(model=model,
                               port=proxy_port,
                               aisbench_cases=warmup_cases)
            # aisbench test
            run_aisbench_cases(model, proxy_port, aisbench_cases)
@@ -67,7 +67,7 @@ def main():
            print(f" - {pkg}")
        sys.exit(1)
    else:
        print("✅ All Python packages have __init__.py files.")
        print("All Python packages have __init__.py files.")


if __name__ == "__main__":