Files
xc-llm-ascend/examples/disaggregated_prefill_v1/toy_proxy_server.py
Pleaplusone df0ec55162 Disaggregate prefill for kv cache register style (#950)
### What this PR does / why we need it?
This PR adopts `LLMDataDist` for a KV-cache-register and `pull_blocks`
style disaggregated prefill implementation. The interface implementation
mainly follows the design of the NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.

This PR can be tested with the following steps:
- Generate the rank table for all machines.
- Execute `toy_proxy_server.py` to launch the disaggregated prefill proxy
server, specifying the prefill IP and port and the decode IP and port.
- Run the prefill server and the decode server.
- Send requests to the disaggregated prefill proxy.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
2025-07-26 17:15:47 +08:00

275 lines
8.7 KiB
Python

# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
import argparse
import itertools
import os
import uuid
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
# Module-level logger obtained from vLLM's logging utilities.
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.

    On startup, one ``httpx.AsyncClient`` is created for every configured
    prefiller and decoder ``(host, port)`` pair (read from the module-level
    ``global_args``, which ``__main__`` sets before starting uvicorn). Each
    client is stored on ``app.state`` as a dict with keys ``client``,
    ``host``, ``port`` and ``id``, and round-robin index iterators are
    created for both pools.

    On shutdown every created client is closed. The close step runs in a
    ``finally`` so the pools are released even if startup fails part-way
    or the application stops because of an error or cancellation.
    """
    # Startup: Initialize client pools for prefiller and decoder services.
    # The lists exist before any client is created so the cleanup below can
    # always iterate them.
    app.state.prefill_clients = []
    app.state.decode_clients = []
    limit = httpx.Limits(max_connections=100000,
                         max_keepalive_connections=100000)
    try:
        pools = (
            (app.state.prefill_clients, global_args.prefiller_instances),
            (app.state.decode_clients, global_args.decoder_instances),
        )
        for clients, instances in pools:
            for i, (host, port) in enumerate(instances):
                base_url = f'http://{host}:{port}/v1'
                clients.append({
                    'client':
                    httpx.AsyncClient(timeout=None,
                                      base_url=base_url,
                                      limits=limit),
                    'host': host,
                    'port': port,
                    'id': i,
                })
        # Round-robin iterators over client *indices* in each pool.
        app.state.prefill_iterator = itertools.cycle(
            range(len(app.state.prefill_clients)))
        app.state.decode_iterator = itertools.cycle(
            range(len(app.state.decode_clients)))
        print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
              f"and {len(app.state.decode_clients)} decode clients.")
        yield
    finally:
        # Shutdown: Close all clients, including any created before a
        # startup failure.
        for client_info in (app.state.prefill_clients +
                            app.state.decode_clients):
            await client_info['client'].aclose()
# FastAPI application; `lifespan` builds the backend HTTP client pools on
# startup and closes them on shutdown.
app = FastAPI(lifespan=lifespan)
def parse_args(argv=None):
    """
    Parse the proxy's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing a
            list allows programmatic launches and tests.

    Returns:
        The ``argparse.Namespace`` extended with ``prefiller_instances``
        and ``decoder_instances``. Each is a list of ``(host, port)``
        tuples that pairs the host and port lists position by position.

    Raises:
        ValueError: If the number of hosts and ports differs on the
            prefiller side or on the decoder side.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    # For prefiller instances
    parser.add_argument("--prefiller-hosts",
                        "--prefiller-host",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        "--prefiller-port",
                        type=int,
                        nargs="+",
                        default=[8100])
    # For decoder instances
    parser.add_argument("--decoder-hosts",
                        "--decoder-host",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports",
                        "--decoder-port",
                        type=int,
                        nargs="+",
                        default=[8200])
    args = parser.parse_args(argv)
    # Validate and pair hosts with ports
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")
    # Create tuples of (host, port) for each service type
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args
def get_next_client(app, service_type: str):
    """
    Pick the next backend client of the requested kind, round-robin.

    Args:
        app: Object exposing ``state`` with ``<kind>_iterator`` (yielding
            indices) and ``<kind>_clients`` (the pool list) attributes.
        service_type: ``'prefill'`` or ``'decode'``.

    Returns:
        The pool entry at the index produced by that kind's iterator.

    Raises:
        ValueError: For any other ``service_type``.
    """
    # service kind -> (index iterator attribute, client pool attribute)
    attribute_names = {
        'prefill': ('prefill_iterator', 'prefill_clients'),
        'decode': ('decode_iterator', 'decode_clients'),
    }
    if service_type not in attribute_names:
        raise ValueError(f"Unknown service type: {service_type}")
    iterator_attr, pool_attr = attribute_names[service_type]
    state = app.state
    position = next(getattr(state, iterator_attr))
    return getattr(state, pool_attr)[position]
async def send_request_to_service(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Send a prefill-only copy of a request to a service and wait for it.

    The body is shallow-copied, so the caller's dict is left untouched. The
    copy is rewritten so that the service only prefills:
    - streaming is disabled;
    - the generation budget is capped at one token;
    - ``kv_transfer_params`` marks the request for remote decode.

    Args:
        client_info: Pool entry whose ``'client'`` is an ``httpx.AsyncClient``.
        endpoint: Path relative to the client's base URL, e.g.
            ``"/completions"``.
        req_data: The original request body.
        request_id: Value forwarded in the ``X-Request-Id`` header.

    Returns:
        The fully-read ``httpx.Response``.

    Raises:
        httpx.HTTPStatusError: If the service answers with an error status.
    """
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    # Chat requests may carry the newer ``max_completion_tokens`` field,
    # which the server may honour instead of ``max_tokens``; cap it as well
    # so the prefill instance does not generate a full answer.
    if "max_completion_tokens" in req_data:
        req_data["max_completion_tokens"] = 1
    # Streaming is disabled for this request, so drop its streaming options.
    req_data.pop("stream_options", None)
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get('OPENAI_API_KEY')
    # Only authenticate when a key is configured; otherwise the header would
    # carry the literal token "None".
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    response = await client_info['client'].post(endpoint,
                                                json=req_data,
                                                headers=headers)
    response.raise_for_status()
    return response
async def stream_service_response(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Forward a request to a service and yield its response body as it arrives.

    Args:
        client_info: Pool entry whose ``'client'`` is an ``httpx.AsyncClient``.
        endpoint: Path relative to the client's base URL.
        req_data: Request body, sent unchanged as JSON.
        request_id: Value forwarded in the ``X-Request-Id`` header.

    Yields:
        Raw ``bytes`` chunks of the response body.

    Raises:
        httpx.HTTPStatusError: On an error status. It is raised on the first
            iteration, before any chunk is yielded.
    """
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get('OPENAI_API_KEY')
    # Only authenticate when a key is configured; otherwise the header would
    # carry the literal token "None".
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    async with client_info['client'].stream("POST",
                                            endpoint,
                                            json=req_data,
                                            headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk
async def _handle_completions(api: str, request: Request):
    """
    Serve one completion request through the prefill -> decode pipeline.

    Steps:
    1. Send a prefill-only copy of the request to the next prefill instance.
    2. If it returns non-empty ``kv_transfer_params``, copy them into the
       request body for the decode instance.
    3. Stream the next decode instance's response back to the caller.

    Args:
        api: Endpoint path (e.g. ``"/completions"``) used on both backends.
        request: Incoming request carrying the JSON body.

    Returns:
        A ``StreamingResponse`` relaying the decode instance's output.

    Raises:
        Any exception raised before the response starts streaming is logged
        and re-raised unchanged.
    """
    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())
        # Get the next prefill client in round-robin fashion
        prefill_client_info = get_next_client(request.app, 'prefill')
        # Send request to prefill service
        response = await send_request_to_service(prefill_client_info, api,
                                                 req_data, request_id)
        # Extract the needed fields
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params
        # Get the next decode client in round-robin fashion
        decode_client_info = get_next_client(request.app, 'decode')
        logger.debug("Using %s %s", prefill_client_info, decode_client_info)
        # With stream=True the decode server is expected to answer with
        # server-sent events (as vLLM's OpenAI-compatible server does), so
        # label the relayed body accordingly instead of always using JSON.
        media_type = ("text/event-stream"
                      if req_data.get("stream") else "application/json")
        # Stream response from decode service
        return StreamingResponse(stream_service_response(
            decode_client_info, api, req_data, request_id=request_id),
                                 media_type=media_type)
    except Exception:
        logger.exception(
            "Error occurred in disagg prefill proxy server - %s endpoint",
            api)
        raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Proxy ``POST /v1/completions`` through the prefill/decode pipeline."""
    return await _handle_completions(api="/completions", request=request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    """Proxy ``POST /v1/chat/completions`` through the prefill/decode pipeline."""
    return await _handle_completions(api="/chat/completions", request=request)
@app.get("/healthcheck")
async def healthcheck():
    """Report that the proxy is up, plus the size of each backend pool."""
    state = app.state
    return dict(
        status="ok",
        prefill_instances=len(state.prefill_clients),
        decode_instances=len(state.decode_clients),
    )
if __name__ == '__main__':
    # `lifespan` reads the parsed options through the module-level name
    # `global_args`. This assignment is at module scope, so it already binds
    # that global; a `global` statement here would be a no-op.
    global_args = parse_args()
    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)