Sync from v0.13
This commit is contained in:
143
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
Normal file
143
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
Normal file
@@ -0,0 +1,143 @@
|
||||
#!/bin/bash
|
||||
|
||||
# benchmark the overhead of disaggregated prefill.
|
||||
# methodology:
|
||||
# - send all request to prefill vLLM instance. It will buffer KV cache.
|
||||
# - then send all request to decode instance.
|
||||
# - The TTFT of decode instance is the overhead.
|
||||
|
||||
set -ex
|
||||
|
||||
kill_gpu_processes() {
  # Forcefully terminate every process that may hold GPU memory, then verify
  # the GPU is actually free before the next benchmark run starts.
  # NOTE(review): kill -9 gives processes no chance to clean up; assumed
  # acceptable for benchmark teardown.
  pgrep pt_main_thread | xargs -r kill -9
  pgrep python3 | xargs -r kill -9
  # vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
  pgrep VLLM | xargs -r kill -9
  # Give the driver time to release GPU memory after the kills.
  sleep 10

  # remove vllm config file
  rm -rf ~/.config/vllm

  # Print the GPU memory usage
  # so that we know if all GPU processes are killed.
  gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
  # The memory usage should be 0 MB.
  echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
}
|
||||
|
||||
wait_for_server() {
  # Poll the given local port until the vLLM server answers HTTP requests.
  # Gives up after 20 minutes; the caller sees status 1 on failure.
  local server_port=$1
  if timeout 1200 bash -c "
    until curl -s localhost:${server_port}/v1/completions > /dev/null; do
      sleep 1
    done"; then
    return 0
  else
    return 1
  fi
}
|
||||
|
||||
|
||||
benchmark() {
  # Measure the TTFT overhead of disaggregated prefill.
  # Arguments: $1 - request rate (QPS) for the decode-side run
  #            $2 - number of output tokens per request

  export VLLM_LOGGING_LEVEL=DEBUG
  export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

  # compare chunked prefill with disaggregated prefill

  results_folder="./results"
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  dataset_name="sonnet"
  dataset_path="../sonnet_4x.txt"
  num_prompts=10
  qps=$1
  prefix_len=50
  input_len=2048
  output_len=$2

  # Prefill instance (GPU 0): produces and buffers KV caches.
  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
    --port 8100 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  # Decode instance (GPU 1): consumes the buffered KV caches.
  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
    --port 8200 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  wait_for_server 8100
  wait_for_server 8200

  # let the prefill instance finish prefill
  vllm bench serve \
    --backend vllm \
    --model $model \
    --dataset-name $dataset_name \
    --dataset-path $dataset_path \
    --sonnet-input-len $input_len \
    --sonnet-output-len "$output_len" \
    --sonnet-prefix-len $prefix_len \
    --num-prompts $num_prompts \
    --port 8100 \
    --save-result \
    --result-dir $results_folder \
    --result-filename disagg_prefill_tp1.json \
    --request-rate "inf"

  # send the request to decode.
  # The TTFT of this command will be the overhead of disagg prefill impl.
  vllm bench serve \
    --backend vllm \
    --model $model \
    --dataset-name $dataset_name \
    --dataset-path $dataset_path \
    --sonnet-input-len $input_len \
    --sonnet-output-len "$output_len" \
    --sonnet-prefix-len $prefix_len \
    --num-prompts $num_prompts \
    --port 8200 \
    --save-result \
    --result-dir $results_folder \
    --result-filename disagg_prefill_tp1_overhead.json \
    --request-rate "$qps"
  kill_gpu_processes
}
|
||||
|
||||
|
||||
main() {
  # Install dependencies, build the 4x sonnet dataset, then run the
  # overhead benchmark once with the default QPS and output length.
  (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
  (which jq) || (apt-get -y install jq)
  (which socat) || (apt-get -y install socat)

  pip install quart httpx datasets

  # Run from the directory containing this script.
  cd "$(dirname "$0")"

  cd ..
  # create sonnet_4x.txt by concatenating sonnet.txt four times, so that
  # long input lengths can be sampled from it
  echo "" > sonnet_4x.txt
  for _ in {1..4}
  do
    cat sonnet.txt >> sonnet_4x.txt
  done
  cd disagg_benchmarks

  # Start from a clean results directory.
  rm -rf results
  mkdir results

  default_qps=1
  default_output_len=1
  benchmark $default_qps $default_output_len

}


main "$@"
|
||||
157
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
Normal file
157
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Requirement: 2x GPUs.
|
||||
|
||||
|
||||
# Model: meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests
|
||||
# Resource: 2x GPU
|
||||
# Approaches:
|
||||
# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4
|
||||
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
|
||||
# Prefilling instance: max_output_token=1
|
||||
# Decoding instance: force the input tokens be the same across requests to bypass prefilling
|
||||
|
||||
set -ex
|
||||
|
||||
kill_gpu_processes() {
  # Tear down every process that may hold GPU memory or a benchmark port.
  pgrep pt_main_thread | xargs -r kill -9
  pgrep python3 | xargs -r kill -9
  # vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
  pgrep VLLM | xargs -r kill -9
  # Also free the proxy (8000) and server (8100/8200) ports in case any
  # listener is still bound.
  for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
  sleep 1
}
|
||||
|
||||
wait_for_server() {
  # Poll the given local port until the vLLM server answers HTTP requests.
  # Gives up after 20 minutes; the caller sees status 1 on failure.
  local server_port=$1
  if timeout 1200 bash -c "
    until curl -s localhost:${server_port}/v1/completions > /dev/null; do
      sleep 1
    done"; then
    return 0
  else
    return 1
  fi
}
|
||||
|
||||
|
||||
launch_chunked_prefill() {
  # Launch two independent chunked-prefill vLLM servers (one per GPU) and a
  # round-robin proxy on port 8000 that load-balances between them.
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  # chunked prefill: two identical instances behind a round-robin proxy
  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
    --port 8100 \
    --max-model-len 10000 \
    --enable-chunked-prefill \
    --gpu-memory-utilization 0.6 &
  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
    --port 8200 \
    --max-model-len 10000 \
    --enable-chunked-prefill \
    --gpu-memory-utilization 0.6 &
  wait_for_server 8100
  wait_for_server 8200
  python3 round_robin_proxy.py &
  sleep 1
}
|
||||
|
||||
|
||||
launch_disagg_prefill() {
  # Launch a disaggregated-prefill pair: GPU 0 produces KV caches, GPU 1
  # consumes them, with a proxy on port 8000 chaining prefill -> decode.
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  # disagg prefill
  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
    --port 8100 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
    --port 8200 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  wait_for_server 8100
  wait_for_server 8200
  python3 disagg_prefill_proxy_server.py &
  sleep 1
}
|
||||
|
||||
|
||||
benchmark() {
  # Run one serving benchmark against the proxy on port 8000.
  # Arguments: $1 - request rate (QPS)
  #            $2 - output tokens per request
  #            $3 - tag used in the result filename ("chunked_prefill" etc.)
  results_folder="./results"
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  dataset_name="sonnet"
  dataset_path="../sonnet_4x.txt"
  num_prompts=100
  qps=$1
  prefix_len=50
  input_len=1024
  output_len=$2
  tag=$3

  vllm bench serve \
    --backend vllm \
    --model $model \
    --dataset-name $dataset_name \
    --dataset-path $dataset_path \
    --sonnet-input-len $input_len \
    --sonnet-output-len "$output_len" \
    --sonnet-prefix-len $prefix_len \
    --num-prompts $num_prompts \
    --port 8000 \
    --save-result \
    --result-dir $results_folder \
    --result-filename "$tag"-qps-"$qps".json \
    --request-rate "$qps"

  # Brief pause so the servers settle before the next QPS sweep iteration.
  sleep 2
}
|
||||
|
||||
|
||||
main() {
  # Install dependencies, build the dataset, then benchmark chunked prefill
  # vs disaggregated prefill at several request rates and plot the results.
  (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
  (which jq) || (apt-get -y install jq)
  (which socat) || (apt-get -y install socat)
  (which lsof) || (apt-get -y install lsof)

  pip install quart httpx matplotlib aiohttp datasets

  # Run from the directory containing this script.
  cd "$(dirname "$0")"

  cd ..
  # create sonnet_4x.txt so that we can sample 2048 tokens for input
  echo "" > sonnet_4x.txt
  for _ in {1..4}
  do
    cat sonnet.txt >> sonnet_4x.txt
  done
  cd disagg_benchmarks

  # Start from a clean results directory.
  rm -rf results
  mkdir results

  default_output_len=6

  export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

  # Approach 1: chunked prefill, two instances behind a round-robin proxy.
  launch_chunked_prefill
  for qps in 2 4 6 8; do
    benchmark $qps $default_output_len chunked_prefill
  done
  kill_gpu_processes

  # Approach 2: disaggregated prefill (KV producer/consumer pair).
  launch_disagg_prefill
  for qps in 2 4 6 8; do
    benchmark $qps $default_output_len disagg_prefill
  done
  kill_gpu_processes

  # Plot TTFT/ITL comparisons into results/*.png
  python3 visualize_benchmark_results.py

}


main "$@"
|
||||
260
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Normal file
260
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from quart import Quart, Response, make_response, request
|
||||
|
||||
# Configure logging: root logger prints INFO and above; the module-level
# logger below is used throughout main().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
    """parse command line arguments"""
    parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")

    # One (flag, type, default, help) tuple per supported option:
    # request handling, backend service locations, and the KV-transfer
    # endpoints used by the P2P NCCL connector.
    option_specs = [
        (
            "--timeout",
            float,
            6 * 60 * 60,
            "Timeout for backend service requests in seconds (default: 21600)",
        ),
        ("--port", int, 8000, "Port to run the server on (default: 8000)"),
        (
            "--prefill-url",
            str,
            "http://localhost:8100",
            "Prefill service base URL (protocol + host[:port])",
        ),
        (
            "--decode-url",
            str,
            "http://localhost:8200",
            "Decode service base URL (protocol + host[:port])",
        ),
        (
            "--kv-host",
            str,
            "localhost",
            "Hostname or IP used by KV transfer (default: localhost)",
        ),
        ("--prefill-kv-port", int, 14579, "Prefill KV port (default: 14579)"),
        ("--decode-kv-port", int, 14580, "Decode KV port (default: 14580)"),
    ]
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Run the P/D disaggregation proxy.

    Parses CLI options, resolves the prefill/decode backend URLs and KV
    transfer addresses, then serves a Quart app that forwards each
    /v1/completions request first to the prefill instance (max_tokens=1)
    and then streams the real completion from the decode instance.
    """
    args = parse_args()

    # Initialize configuration using command line parameters
    AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
    PREFILL_SERVICE_URL = args.prefill_url
    DECODE_SERVICE_URL = args.decode_url
    PORT = args.port

    # host:port strings advertised to the backends for KV-cache transfer.
    PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
    DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"

    logger.info(
        "Proxy resolved KV addresses -> prefill: %s, decode: %s",
        PREFILL_KV_ADDR,
        DECODE_KV_ADDR,
    )

    app = Quart(__name__)

    # Attach the configuration object to the application instance so helper
    # coroutines can read the resolved backend URLs and timeouts without using
    # globals.
    app.config.update(
        {
            "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
            "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
            "DECODE_SERVICE_URL": DECODE_SERVICE_URL,
            "PREFILL_KV_ADDR": PREFILL_KV_ADDR,
            "DECODE_KV_ADDR": DECODE_KV_ADDR,
        }
    )

    def _normalize_base_url(url: str) -> str:
        """Remove any trailing slash so path joins behave predictably."""
        return url.rstrip("/")

    def _get_host_port(url: str) -> str:
        """Return the hostname:port portion for logging and KV headers."""
        parsed = urlparse(url)
        host = parsed.hostname or "localhost"
        port = parsed.port
        if port is None:
            # Fall back to the scheme's default port when none is given.
            port = 80 if parsed.scheme == "http" else 443
        return f"{host}:{port}"

    PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
    DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
    # KV target header points at the decode service's host:port.
    KV_TARGET = _get_host_port(DECODE_SERVICE_URL)

    def _build_headers(request_id: str) -> dict[str, str]:
        """Construct the headers expected by vLLM's P2P disagg connector."""
        headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
        api_key = os.environ.get("OPENAI_API_KEY")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    async def _run_prefill(
        request_path: str,
        payload: dict,
        headers: dict[str, str],
        request_id: str,
    ):
        """POST the prefill request and wait for it to complete.

        Raises RuntimeError on a non-200 response, timeout, or client error;
        the caller maps that to a 500 for the end user.
        """
        url = f"{PREFILL_BASE}{request_path}"
        start_ts = time.perf_counter()
        logger.info("[prefill] start request_id=%s url=%s", request_id, url)
        try:
            async with (
                aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
                session.post(url=url, json=payload, headers=headers) as resp,
            ):
                if resp.status != 200:
                    error_text = await resp.text()
                    raise RuntimeError(
                        f"Prefill backend error {resp.status}: {error_text}"
                    )
                # Drain the body so the request fully completes before we
                # move on to the decode stage.
                await resp.read()
                logger.info(
                    "[prefill] done request_id=%s status=%s elapsed=%.2fs",
                    request_id,
                    resp.status,
                    time.perf_counter() - start_ts,
                )
        except asyncio.TimeoutError as exc:
            raise RuntimeError(f"Prefill service timeout at {url}") from exc
        except aiohttp.ClientError as exc:
            raise RuntimeError(f"Prefill service unavailable at {url}") from exc

    async def _stream_decode(
        request_path: str,
        payload: dict,
        headers: dict[str, str],
        request_id: str,
    ):
        """Async generator yielding the decode service's response bytes.

        Errors are yielded as JSON error payloads rather than raised, since
        the HTTP response may already be streaming when they occur.
        """
        url = f"{DECODE_BASE}{request_path}"
        # Stream tokens from the decode service once the prefill stage has
        # materialized KV caches on the target workers.
        logger.info("[decode] start request_id=%s url=%s", request_id, url)
        try:
            async with (
                aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
                session.post(url=url, json=payload, headers=headers) as resp,
            ):
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(
                        "Decode backend error %s - %s", resp.status, error_text
                    )
                    err_msg = (
                        '{"error": "Decode backend error ' + str(resp.status) + '"}'
                    )
                    yield err_msg.encode()
                    return
                logger.info(
                    "[decode] streaming response request_id=%s status=%s",
                    request_id,
                    resp.status,
                )
                async for chunk_bytes in resp.content.iter_chunked(1024):
                    yield chunk_bytes
                logger.info("[decode] finished streaming request_id=%s", request_id)
        except asyncio.TimeoutError:
            logger.error("Decode service timeout at %s", url)
            yield b'{"error": "Decode service timeout"}'
        except aiohttp.ClientError as exc:
            logger.error("Decode service error at %s: %s", url, exc)
            yield b'{"error": "Decode service unavailable"}'

    async def process_request():
        """Process a single request through prefill and decode stages"""
        try:
            original_request_data = await request.get_json()

            # Create prefill request (max_tokens=1)
            prefill_request = original_request_data.copy()
            prefill_request["max_tokens"] = 1
            if "max_completion_tokens" in prefill_request:
                prefill_request["max_completion_tokens"] = 1

            # Execute prefill stage
            # The request id encodes both KV socket addresses so the backend can
            # shuttle tensors directly via NCCL once the prefill response
            # completes.
            request_id = (
                f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
                f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
            )

            headers = _build_headers(request_id)
            await _run_prefill(request.path, prefill_request, headers, request_id)

            # Execute decode stage and stream response
            # Pass the unmodified user request so the decode phase can continue
            # sampling with the already-populated KV cache.
            generator = _stream_decode(
                request.path, original_request_data, headers, request_id
            )
            response = await make_response(generator)
            response.timeout = None  # Disable timeout for streaming response
            return response

        except Exception:
            logger.exception("Error processing request")
            return Response(
                response=b'{"error": "Internal server error"}',
                status=500,
                content_type="application/json",
            )

    @app.route("/v1/completions", methods=["POST"])
    async def handle_request():
        """Handle an incoming /v1/completions request."""
        try:
            return await process_request()
        except asyncio.CancelledError:
            logger.warning("Request cancelled")
            return Response(
                response=b'{"error": "Request cancelled"}',
                status=503,
                content_type="application/json",
            )

    # Start the Quart server (binds to localhost by default; pass
    # host="0.0.0.0" here to expose it externally).
    app.run(port=PORT)


if __name__ == "__main__":
    main()
|
||||
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class RateLimiter:
    """Simple per-second rate limiter: at most ``rate_limit`` acquisitions
    are granted in any one-second refill window."""

    def __init__(self, rate_limit):
        self.rate_limit = rate_limit  # Requests per second
        self.num_available_tokens = rate_limit  # Tokens left in this window
        self.last_refill = time.monotonic()  # When the window last reset
        self.lock = asyncio.Lock()  # Guards the counters above

    async def acquire(self):
        """Block until a token is available, then consume it and return True."""
        while True:
            async with self.lock:
                now = time.monotonic()
                since_refill = now - self.last_refill

                # A full second has passed: reset the window.
                if since_refill > 1.0:
                    self.num_available_tokens = self.rate_limit
                    self.last_refill = now

                # A token is free: consume it and report success.
                if self.num_available_tokens > 0:
                    self.num_available_tokens -= 1
                    return True

                # No tokens left; sleep (outside the lock) until the
                # current window is due to expire, then retry.
                remaining = 1.0 - since_refill
            await asyncio.sleep(remaining)

    async def __aenter__(self):
        """Acquire a token when entering an ``async with`` block."""
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Nothing to release on exit."""
        pass
|
||||
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
|
||||
|
||||
class RequestQueue:
    """Request queue manager with concurrency control.

    Tasks are added with ``enqueue()`` (bounded by ``max_queue_size``) and
    drained by a single ``process()`` loop, which awaits at most
    ``max_concurrent`` tasks at a time via a semaphore.
    """

    def __init__(self, max_concurrent, max_queue_size):
        # Maximum concurrent requests
        self.max_concurrent = max_concurrent
        self.max_queue_size = max_queue_size  # Maximum queue size
        # Concurrency control
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.queue = deque()  # Request queue
        self.queue_size = 0  # Current queue size
        self.lock = asyncio.Lock()  # Protects queue and queue_size

    async def enqueue(self, task):
        """Add a request task to the queue.

        Returns True on success, False when the queue is already full.
        """
        async with self.lock:
            if self.queue_size >= self.max_queue_size:
                return False

            self.queue.append(task)
            self.queue_size += 1
            return True

    async def process(self):
        """Process queued requests using semaphore for concurrency control.

        Runs forever; schedule it as a background task.
        """
        while True:
            task = None
            # BUGFIX: pop under the lock only. The original held self.lock
            # while awaiting the task itself, which blocked enqueue() (and
            # thus all producers) for the entire duration of each request.
            async with self.lock:
                if self.queue:
                    task = self.queue.popleft()
                    self.queue_size -= 1
            if task is not None:
                async with self.semaphore:
                    await task
            await asyncio.sleep(0.01)  # Yield control to event loop
|
||||
63
benchmarks/disagg_benchmarks/round_robin_proxy.py
Normal file
63
benchmarks/disagg_benchmarks/round_robin_proxy.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
class RoundRobinProxy:
    """Minimal HTTP proxy that alternates requests between backend ports."""

    def __init__(self, target_ports):
        # Backend ports, served strictly round-robin via an endless cycle.
        self.target_ports = target_ports
        self.port_cycle = itertools.cycle(self.target_ports)

    async def handle_request(self, request):
        """Forward one aiohttp request to the next backend and stream back
        its response; any failure is returned to the client as a 500."""
        # Pick the next backend and rebuild the URL with the original
        # path and query string.
        target_port = next(self.port_cycle)
        target_url = f"http://localhost:{target_port}{request.path_qs}"

        async with aiohttp.ClientSession() as session:
            try:
                # Forward the request
                # NOTE(review): headers are forwarded verbatim, including
                # Host/Content-Length — presumably fine for localhost
                # benchmarking; verify before reusing elsewhere.
                async with session.request(
                    method=request.method,
                    url=target_url,
                    headers=request.headers,
                    data=request.content,
                ) as response:
                    # Start sending the response
                    resp = web.StreamResponse(
                        status=response.status, headers=response.headers
                    )
                    await resp.prepare(request)

                    # Stream the response content
                    async for chunk in response.content.iter_any():
                        await resp.write(chunk)

                    await resp.write_eof()
                    return resp

            except Exception as e:
                # Surface backend/connection failures to the client.
                return web.Response(text=f"Error: {str(e)}", status=500)
|
||||
|
||||
|
||||
async def main():
    """Run a round-robin proxy on port 8000 in front of ports 8100/8200."""
    proxy = RoundRobinProxy([8100, 8200])
    app = web.Application()
    # Catch-all route: every method and every path is forwarded.
    app.router.add_route("*", "/{path:.*}", proxy.handle_request)

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "localhost", 8000)
    await site.start()

    print("Proxy server started on http://localhost:8000")

    # Keep the server running
    await asyncio.Event().wait()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
47
benchmarks/disagg_benchmarks/visualize_benchmark_results.py
Normal file
47
benchmarks/disagg_benchmarks/visualize_benchmark_results.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
if __name__ == "__main__":
    # Load the per-strategy/per-QPS result files written by
    # disagg_performance_benchmark.sh into a single list of records.
    data = []
    for name in ["disagg_prefill", "chunked_prefill"]:
        for qps in [2, 4, 6, 8]:
            with open(f"results/{name}-qps-{qps}.json") as f:
                x = json.load(f)
                # Tag each record with its strategy and request rate so the
                # frame can be filtered below.
                x["name"] = name
                x["qps"] = qps
                data.append(x)

    df = pd.DataFrame.from_dict(data)
    dis_df = df[df["name"] == "disagg_prefill"]
    chu_df = df[df["name"] == "chunked_prefill"]

    plt.style.use("bmh")
    plt.rcParams["font.size"] = 20

    # One figure per latency metric (TTFT and inter-token latency stats),
    # comparing the two strategies across QPS; saved to results/<key>.png.
    for key in [
        "mean_ttft_ms",
        "median_ttft_ms",
        "p99_ttft_ms",
        "mean_itl_ms",
        "median_itl_ms",
        "p99_itl_ms",
    ]:
        fig, ax = plt.subplots(figsize=(11, 7))
        plt.plot(
            dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
        )
        plt.plot(
            chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
        )
        ax.legend()

        ax.set_xlabel("QPS")
        ax.set_ylabel(key)
        ax.set_ylim(bottom=0)
        fig.savefig(f"results/{key}.png")
        # Close each figure after saving to avoid accumulating open figures.
        plt.close(fig)
|
||||
Reference in New Issue
Block a user