Disaggregate prefill for kv cache register style (#950)

### What this PR does / why we need it?
This PR adopts `LLMDataDist` for kv cache registration and a `pull_blocks`-style
disaggregated prefill implementation. The interface implementation
mainly follows the design of the NIXL PR:
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953

This PR can be tested with the following steps:
- Generate the rank table for all machines.
- Execute `toy_proxy_server.py` to launch the disaggregated prefill proxy server,
specifying the prefill host/port and the decode host/port.
- Run the prefill server and the decode server.
- Send requests to the disaggregated prefill proxy (a condensed single-node
sketch of this flow is shown below).
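
For orientation, here is a condensed single-node sketch of that flow, loosely
following `tests/e2e/pd_disaggreate/run_edge_case_test.sh` added by this PR. The
model name, NIC name, and ports below are placeholders, not a prescribed setup:

```shell
# Illustrative smoke test on one machine: 1 NPU for prefill, 1 NPU for decode.
cd examples/disaggregate_prefill_v1
bash gen_ranktable.sh --ips <local-ip> --network-card-name <nic-name> \
    --prefill-device-cnt 1 --decode-device-cnt 1
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=$PWD/ranktable.json

# Prefill (kv_producer) and decode (kv_consumer) instances differ only in kv_role.
ASCEND_RT_VISIBLE_DEVICES=0 VLLM_LLMDD_RPC_PORT=5559 vllm serve <model> --port 8001 --enforce-eager \
    --kv-transfer-config '{"kv_connector": "LLMDataDistCMgrConnector", "kv_role": "kv_producer", "kv_buffer_device": "npu", "kv_parallel_size": 1, "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"}' &

ASCEND_RT_VISIBLE_DEVICES=1 VLLM_LLMDD_RPC_PORT=6000 vllm serve <model> --port 8002 --enforce-eager \
    --kv-transfer-config '{"kv_connector": "LLMDataDistCMgrConnector", "kv_role": "kv_consumer", "kv_buffer_device": "npu", "kv_parallel_size": 1, "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"}' &

# The toy proxy forwards each request to a prefill instance first, then streams
# the response from a decode instance.
python toy_proxy_server.py --port 8192 --prefiller-ports 8001 --decoder-ports 8002
```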

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
Author: Pleaplusone
Date: 2025-07-26 17:15:47 +08:00 (committed by GitHub)
Parent: 17a430f7b8
Commit: df0ec55162
28 changed files with 2833 additions and 144 deletions

View File

@@ -42,8 +42,7 @@ jobs:
strategy:
matrix:
vllm_verison: [
# revert me when V1 disaggregation prefill is merged in main
# main,
main,
v0.9.1
]
name: vLLM Ascend prefilling decoding disaggregation test
@@ -107,6 +106,6 @@ jobs:
pip install -r requirements-dev.txt
pip install -v -e .
- name: Run vllm-project/vllm-ascend PD Disaggregation test
- name: Run vllm-project/vllm-ascend PD Disaggregation edge test
run: |
pytest -sv tests/e2e/pd_disaggreate/test_pd_e2e.py
bash tests/e2e/pd_disaggreate/run_edge_case_test.sh

View File

@@ -0,0 +1,230 @@
# Disaggregated Prefill-Decode Deployment Guide
## Overview
This demo document provides instructions for running a disaggregated vLLM-ascend service with separate prefill and decode stages across 4 nodes, using 16 Ascend NPUs for two prefill nodes (P1/P2) and 16 Ascend NPUs for two decode nodes (D1/D2).
## Prerequisites
- Ascend NPU environment with vLLM 0.9.1 installed
- Network interfaces configured for distributed communication (e.g., eth0)
- Model weights located at `/data01/deepseek_r1_w8a8_zhw`
## Rank table generation
The rank table is a JSON file that specifies the mapping of Ascend NPU ranks to nodes. Run the following command on every node to generate a rank table covering 16 prefill cards and 16 decode cards:
```shell
cd vllm-ascend/examples/disaggregate_prefill_v1/
bash gen_ranktable.sh --ips 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 \
--npus-per-node 8 --network-card-name enp189s0f0 --prefill-device-cnt 16 --decode-device-cnt 16
```
The rank table will be generated at `/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json`.
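For reference, the generated file has roughly the following shape (field names follow `gen_ranktable.py` in this PR; server IPs, device IPs, and cluster IDs will differ per cluster, and A3 SoCs additionally record `super_pod_id` and `super_device_id` per device):
```shell
cat /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
# {
#     "version": "1.2",
#     "server_count": "4",
#     "prefill_device_list": [
#         {"server_id": "172.19.32.175", "device_id": "0", "device_ip": "<hccn device ip>", "cluster_id": "1"},
#         ...
#     ],
#     "decode_device_list": [
#         {"server_id": "172.19.123.51", "device_id": "0", "device_ip": "<hccn device ip>", "cluster_id": "17"},
#         ...
#     ],
#     "status": "completed"
# }
```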
## Start disaggregated vLLM-ascend service
Execution sequence:
- The four configured node IPs are: 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36
- Start Prefill on Node 1 (P1)
- Start Prefill on Node 2 (P2)
- Start Decode on Node 1 (D1)
- Start Decode on Node 2 (D2)
- Start the proxy server on Node 1
* Run prefill server P1 on the first node
```shell
export HCCL_IF_IP=172.19.32.175 # node ip
export GLOO_SOCKET_IFNAME="eth0" # network card name
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--data-parallel-size 2 \
--data-parallel-size-local 1 \
--api-server-count 2 \
--data-parallel-address 172.19.32.175 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_producer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
```
* Run prefill server P2 on the second node
```shell
export HCCL_IF_IP=172.19.241.49
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--headless \
--data-parallel-size 2 \
--data-parallel-start-rank 1 \
--data-parallel-size-local 1 \
--data-parallel-address 172.19.32.175 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_producer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
```
* Run decode server D1 on the third node
```shell
export HCCL_IF_IP=172.19.123.51
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--data-parallel-size 2 \
--data-parallel-size-local 1 \
--api-server-count 2 \
--data-parallel-address 172.19.123.51 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
```
* Run decode server D2 on the last node
```shell
export HCCL_IF_IP=172.19.190.36
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--headless \
--data-parallel-size 2 \
--data-parallel-start-rank 1 \
--data-parallel-size-local 1 \
--data-parallel-address 172.19.123.51 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
```
* Run proxy server on the first node
```shell
cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1
python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
```
* Verification
Verify the deployment by sending a request through the proxy server endpoint:
```shell
curl http://localhost:1025/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek",
"prompt": "Who are you?",
"max_tokens": 100,
"temperature": 0
}'
```
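The toy proxy also exposes a `/healthcheck` endpoint (see `toy_proxy_server.py` in this PR) that reports how many prefill and decode clients it holds; a quick liveness probe might look like:
```shell
curl http://localhost:1025/healthcheck
# Expected shape of the response (instance counts depend on the deployment):
# {"status": "ok", "prefill_instances": 1, "decode_instances": 1}
```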
* Performance
Test performance with the vLLM benchmark script:
```shell
cd /vllm-workspace/vllm/benchmarks
python3 benchmark_serving.py \
--backend vllm \
--dataset-name random \
--random-input-len 4096 \
--random-output-len 1536 \
--num-prompts 256 \
--ignore-eos \
--model deepseek \
--tokenizer /data01/deepseek_r1_w8a8_zhw \
--host localhost \
--port 8000 \
--endpoint /v1/completions \
--max-concurrency 4 \
--request-rate 4
```

View File

@@ -0,0 +1,120 @@
import argparse
import json
import os
import torch.distributed as dist
from vllm_ascend.soc_info import NPUSocInfo
parser = argparse.ArgumentParser(
description="Arguments of rank table generator", )
parser.add_argument("--local-host", type=str, required=True, help="local ip")
parser.add_argument("--prefill-device-cnt",
type=int,
required=True,
help="number of prefill devices")
parser.add_argument("--decode-device-cnt",
type=int,
required=True,
help="number of decode devices")
args = parser.parse_args()
local_host = args.local_host
prefill_device_cnt = args.prefill_device_cnt
decode_device_cnt = args.decode_device_cnt
print("enter py")
hccn_tool_path = os.environ.get("HCCN_TOOL_PATH",
"/usr/local/Ascend/driver/tools/hccn_tool")
master_addr = os.environ.get("MASTER_ADDR")
master_port = os.environ.get("MASTER_PORT")
rank = os.environ.get("RANK")
local_rank = os.environ.get("LOCAL_RANK")
# This variable is set by torchrun,
# and is different from WORLD_SIZE in gen_rank_table.sh.
world_size = os.environ.get("WORLD_SIZE")
soc_info = NPUSocInfo()
def get_cmd_stdout(cmd):
import subprocess
return subprocess.run(cmd, capture_output=True,
shell=True).stdout.decode("utf-8").strip()
print(f"local_host: {local_host}")
print("gen ranktable.json")
num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split(
":")[1].strip()
num_cards = int(num_cards)
chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split(
"\n")[0].split(":")[1].strip()
chips_per_card = int(chips_per_card)
# generate local device list for local rank 0, and gather it to all ranks
local_device_list: list[dict[str, str]] = list()
if local_rank == "0":
super_pod_id = "0"
for card_id in range(num_cards):
for chip_id in range(chips_per_card):
device_id = card_id * chips_per_card + chip_id
if soc_info.is_a3:
device_ip = get_cmd_stdout(
f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
).split(":")[1].strip()
super_device_id = get_cmd_stdout(
f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID"
).split(":")[1].strip()
super_pod_id = get_cmd_stdout(
f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\""
).split(":")[1].strip()
else:
device_ip = get_cmd_stdout(
f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr"
).split(":")[1].strip()
device_info = {
"server_id": local_host,
"device_id": str(device_id),
"device_ip": str(device_ip),
}
if soc_info.is_a3:
device_info.update({
"super_pod_id": str(super_pod_id),
"super_device_id": str(super_device_id)
})
local_device_list.append(device_info)
dist.init_process_group(backend=dist.Backend.GLOO)
global_device_list = [None] * dist.get_world_size()
dist.all_gather_object(global_device_list, local_device_list)
global_device_list = [
device_info for device_list in global_device_list
for device_info in device_list # type: ignore[attr-defined]
]
cnt = 1
for device_info in global_device_list: # type: ignore[assignment]
device_info["cluster_id"] = str(cnt)
cnt += 1
assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \
"prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster"
ranktable = {
"version":
"1.2",
"server_count":
str(world_size),
"prefill_device_list":
global_device_list[:prefill_device_cnt],
"decode_device_list":
global_device_list[prefill_device_cnt:prefill_device_cnt +
decode_device_cnt],
"status":
"completed"
}
if local_rank == '0':
with open("ranktable.json", "w") as f:
json.dump(ranktable, f, indent=4)
print("gen ranktable.json done")

View File

@@ -0,0 +1,79 @@
#!/bin/bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH}
NPUS_PER_NODE=8
while [[ $# -gt 0 ]]; do
case "$1" in
--ips)
shift
while [[ $# -gt 0 && ! "$1" == --* ]]; do
IPs+=("$1")
shift
done
;;
--npus-per-node)
shift
NPUS_PER_NODE="$1"
shift
;;
--network-card-name)
shift
NETWORK_CARD_NAME="$1"
shift
;;
--prefill-device-cnt)
shift
PREFILL_DEVICE_CNT="$1"
shift
;;
--decode-device-cnt)
shift
DECODE_DEVICE_CNT="$1"
shift
;;
esac
done
LOCAL_HOSTS=($(hostname -I))
LOCAL_HOST="127.0.0.1"
MASTER_ADDR=${IPs[0]}
MASTER_PORT=6657
NNODES=${#IPs[@]}
NODE_RANK="8"
for i in "${!IPs[@]}"; do
ip="${IPs[$i]}"
for local_host in "${LOCAL_HOSTS[@]}"; do
if [[ "$local_host" == "$ip" ]]; then
LOCAL_HOST=$local_host
NODE_RANK=$i
break 2
fi
done
done
if [[ $NODE_RANK == "" ]];then
echo "[Error] para \"NODE_RANK\" must be defined"
exit 1
fi
WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES))
RANKSTART=`expr $NPUS_PER_NODE \* $NODE_RANK`
echo "========>param:"
echo "LOCAL_HOST": $LOCAL_HOST
echo "WORLD_SIZE: " $WORLD_SIZE
echo "RANKSTART": $RANKSTART
echo "NNODES": $NNODES
echo "NODE_RANK": $NODE_RANK
echo "==============="
if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then
GLOO_SOCKET_IFNAME=$NETWORK_CARD_NAME torchrun \
--nproc_per_node 1 \
--nnodes ${NNODES} \
--node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \
--master_port ${MASTER_PORT} \
gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT
fi

View File

@@ -0,0 +1,32 @@
export HCCL_IF_IP=141.61.39.117
export GLOO_SOCKET_IFNAME="enp48s3u1u1"
export TP_SOCKET_IFNAME="enp48s3u1u1"
export HCCL_SOCKET_IFNAME="enp48s3u1u1"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=path-to-rank-table
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
vllm serve model_path \
--host 0.0.0.0 \
--port 20002 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name dsv3 \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": 0,
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_connector_v1_a3"
}' \
--additional-config \
'{"enable_graph_mode": "True"}'\

View File

@@ -0,0 +1,275 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
import argparse
import itertools
import os
import uuid
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize client pools for prefiller and decoder services
app.state.prefill_clients = []
app.state.decode_clients = []
limit = httpx.Limits(max_connections=100000,
max_keepalive_connections=100000)
# Create prefill clients
for i, (host, port) in enumerate(global_args.prefiller_instances):
prefiller_base_url = f'http://{host}:{port}/v1'
app.state.prefill_clients.append({
'client':
httpx.AsyncClient(timeout=None,
base_url=prefiller_base_url,
limits=limit),
'host':
host,
'port':
port,
'id':
i
})
# Create decode clients
for i, (host, port) in enumerate(global_args.decoder_instances):
decoder_base_url = f'http://{host}:{port}/v1'
app.state.decode_clients.append({
'client':
httpx.AsyncClient(timeout=None,
base_url=decoder_base_url,
limits=limit),
'host':
host,
'port':
port,
'id':
i
})
# Initialize round-robin iterators
app.state.prefill_iterator = itertools.cycle(
range(len(app.state.prefill_clients)))
app.state.decode_iterator = itertools.cycle(
range(len(app.state.decode_clients)))
print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
f"and {len(app.state.decode_clients)} decode clients.")
yield
# Shutdown: Close all clients
for client_info in app.state.prefill_clients:
await client_info['client'].aclose()
for client_info in app.state.decode_clients:
await client_info['client'].aclose()
# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
# For prefiller instances
parser.add_argument("--prefiller-hosts",
"--prefiller-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
"--prefiller-port",
type=int,
nargs="+",
default=[8100])
# For decoder instances
parser.add_argument("--decoder-hosts",
"--decoder-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports",
"--decoder-port",
type=int,
nargs="+",
default=[8200])
args = parser.parse_args()
# Validate and pair hosts with ports
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
# Create tuples of (host, port) for each service type
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
def get_next_client(app, service_type: str):
"""
Get the next client in round-robin fashion.
Args:
app: The FastAPI app instance
service_type: Either 'prefill' or 'decode'
Returns:
The next client to use
"""
if service_type == 'prefill':
client_idx = next(app.state.prefill_iterator)
return app.state.prefill_clients[client_idx]
elif service_type == 'decode':
client_idx = next(app.state.decode_iterator)
return app.state.decode_clients[client_idx]
else:
raise ValueError(f"Unknown service type: {service_type}")
async def send_request_to_service(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Send a request to a service using a client from the pool.
"""
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
response = await client_info['client'].post(endpoint,
json=req_data,
headers=headers)
response.raise_for_status()
return response
async def stream_service_response(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Asynchronously stream response from a service using a client from the pool.
"""
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async with client_info['client'].stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
request_id = str(uuid.uuid4())
# Get the next prefill client in round-robin fashion
prefill_client_info = get_next_client(request.app, 'prefill')
# Send request to prefill service
response = await send_request_to_service(prefill_client_info, api,
req_data, request_id)
# Extract the needed fields
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, 'decode')
logger.debug("Using %s %s", prefill_client_info, decode_client_info)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(decode_client_info,
api,
req_data,
request_id=request_id):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)
@app.get("/healthcheck")
async def healthcheck():
"""Simple endpoint to check if the server is running."""
return {
"status": "ok",
"prefill_instances": len(app.state.prefill_clients),
"decode_instances": len(app.state.decode_clients)
}
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@@ -23,12 +23,18 @@ Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
import os
from unittest.mock import patch
import pytest
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
@pytest.mark.skipif(
True,
reason=
"Current disaggregated pd implementation may cause memory pulse, which will cause this test OOM, skip this test until the ringmla is ready "
)
@patch.dict(
os.environ, {
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
@@ -54,6 +60,11 @@ def test_generate_with_allgather():
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.skipif(
True,
reason=
"Current disaggregated pd implementation may cause memory pulse, which will cause this test OOM, skip this test until the ringmla is ready "
)
@patch.dict(os.environ, {
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1"

View File

@@ -23,6 +23,7 @@ Run `pytest tests/test_offline_inference.py`.
import os
from unittest.mock import patch
import pytest
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
@@ -93,6 +94,10 @@ def test_models_distributed_DeepSeek_dbo():
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.skip(
reason=
"deepseek dbo dose not consider the support on half precision float, will enable this ut after we actually support it"
)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
def test_models_distributed_DeepSeekV3_dbo():
example_prompts = ["The president of the United States is"] * 41
@@ -113,6 +118,7 @@ def test_models_distributed_DeepSeekV3_dbo():
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
def test_models_distributed_DeepSeek_W8A8():
example_prompts = [
"Hello, my name is",

View File

@@ -0,0 +1,141 @@
#!/bin/bash
export LCCL_DETERMINISTIC=1
export HCCL_DETERMINISTIC=true
export CLOSE_MATMUL_K_SHIFT=1
export VLLM_USE_V1=1
set -xe
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B-Instruct"
)
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Gen ranktable
RANKTABLE_PATH=${GIT_ROOT}/examples/disaggregate_prefill_v1/ranktable.json
if [ -f "$RANKTABLE_PATH" ]; then
rm "$RANKTABLE_PATH"
fi
cd ${GIT_ROOT}/examples/disaggregate_prefill_v1
LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'`
bash gen_ranktable.sh --ips $LOCAL_HOST --network-card-name enp189s0f0 --prefill-device-cnt 1 --decode-device-cnt 1
cd -
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="$RANKTABLE_PATH"
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/health > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Helper to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == *"deepseek"* ]]; then
extra_args="--trust-remote-code"
fi
echo "$extra_args"
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Start prefill instance
PREFILL_PORT=8001
BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=0 VLLM_LLMDD_RPC_PORT=5559 vllm serve $model_name \
--port $PREFILL_PORT \
--seed 1024 \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.8 \
--kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Start decode instance
DECODE_PORT=8002
# Build the command with or without model-specific args
BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=1 VLLM_LLMDD_RPC_PORT=6000 vllm serve $model_name \
--port $DECODE_PORT \
--seed 1024 \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.8 \
--kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Wait for all instances to start
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PREFILL_PORT
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $DECODE_PORT
# Build the command for the proxy server with all the hosts and ports
PROXY_PORT=8192
PROXY_CMD="python ${GIT_ROOT}/examples/disaggregate_prefill_v1/toy_proxy_server.py --port $PROXY_PORT"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/e2e/pd_disaggreate/test_edge_cases.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# This code is from: https://github.com/vllm-project/vllm/blob/main/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
import openai
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_PORT = os.getenv("PROXY_PORT", None)
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
raise ValueError(
"Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501
SHORT_PROMPT = "Red Hat is "
def test_edge_cases():
# Set the OpenAI API key and base URL
decode_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{DECODE_PORT}/v1",
)
prefill_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PREFILL_PORT}/v1",
)
proxy_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PROXY_PORT}/v1",
)
# Get the list of models
models = decode_client.models.list()
MODEL = models.data[0].id
# (1) Check that we can handle a very short prompt,
# less than the length of the block size.
completion = proxy_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"SMALL PROMPT: {proxy_response=}")
print(f"SMALL PROMPT: {prefill_response=}")
assert proxy_response == prefill_response
# (2) Check that we can handle a full prefix cache
# hit on the D worker but not on the P worker.
# (2a): prime the D worker.
completion = decode_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
decode_response = completion.choices[0].text
# (2b): send via the P/D setup
completion = proxy_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
print(f"FULL CACHE HIT: {proxy_response=}")
assert proxy_response == decode_response
# (3) Check that we can handle a partial prefix cache
# hit on the D worker.
completion = proxy_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"PARTIAL CACHE HIT: {proxy_response=}")
assert proxy_response == prefill_response

View File

@@ -251,7 +251,10 @@ class TestAscendAttentionBackendImpl(TestBase):
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
kv_cache = [k_cache, v_cache]
ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
metadata = MagicMock()
metadata.num_actual_tokens = torch.randn(10, 8 * 64)
@@ -261,7 +264,7 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.query_lens = torch.randn(10, 8 * 64)
layer = self.layer
layer.quant_method = MagicMock()
layer.quant_method.apply.return_value = kv_cache
layer.quant_method.apply.return_value = ret_value
output = self.impl.forward(layer,
query,

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
create_vllm_config)
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import \
LLMDataDistCMgrConnectorMetadata
def test_basic_inferface():
"""Unit test for basic LLMDataDistCMgrConnector interface functionality."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_id = request.request_id
scheduler.add_request(request)
# Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata)
assert len(kv_connector_metadata.requests) == 1
assert request_id in kv_connector_metadata.requests
req_meta = kv_connector_metadata.requests[request_id]
for block_id, block in zip(
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id]):
assert block_id == block.block_id

View File

@@ -0,0 +1,163 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a Remote Decode request."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
# Ensure the request is finished after 1 token.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None
# Request freed in Scheduler and blocks should be freed
assert request_id in scheduler.finished_req_ids
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (2c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_id]
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
assert_scheduler_empty(scheduler)
def test_prefix_cache_lifecycle():
"""Test that remote decode params still works with a prefix cache hit."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote_a = create_request(request_id=1, num_tokens=NUM_TOKENS)
scheduler.add_request(request_remote_a)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote_a],
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
#####################
# Actual Test: confirm we send all blocks.
# Step (1): Send the KV Transfer.
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
assert (len(
kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
1))
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_remote.request_id]
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,248 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
from vllm_ascend.utils import vllm_version_is
def test_basic_lifecycle():
"""Test lifecycle of a remote prefill."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
START_FREE_BLOCK_QUEUE_SIZE = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1):
# (1a): schedule()
scheduler_output = scheduler.schedule()
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
if vllm_version_is("0.9.1"):
assert len(scheduler_output.scheduled_cached_reqs) == 0
else:
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
assert (request.num_computed_tokens == 0)
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
assert (block_pool.free_block_queue.num_free_blocks
< START_FREE_BLOCK_QUEUE_SIZE)
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block._block_hash is None
# (1b): forward()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert not engine_core_outputs or not engine_core_outputs[0].outputs
# STEP (2):
# (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# STEP (3):
# (3a): schedule(): this should actually schedule.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
# Confirm the block are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
# Confirm the rest of the prompt is scheduled in this step.
scheduled_req = scheduler_output.scheduled_new_reqs[0]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_computed_tokens = scheduled_req.num_computed_tokens
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
# (3b): execute_model()
model_runner_output = create_model_runner_output([request])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
if vllm_version_is("0.9.1"):
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)
def test_no_spurious_prefix_caching():
"""
With P/D, blocks can be allocated but uncomputed for
multiple engine steps. This test confirms that we do
not accidentally have cache hits against uncomputed
blocks.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 and a half full external blocks.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
)
# Schedule the remote prefill request. This should not
# cause any blocks to be cached.
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_remote.request_id]
# Remote blocks should not be cached.
for block in remote_blocks:
assert block.ref_cnt == 1
assert block._block_hash is None
def test_full_block_prompt():
"""Test that we handle a prompt that is the full block size."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
# # STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# # STEP (3): Run as usual.
scheduler_output = scheduler.schedule()
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
NUM_TOKENS - 1)
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
model_runner_output = create_model_runner_output([request])
scheduler.update_from_output(scheduler_output, model_runner_output)
# # Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
if vllm_version_is("0.9.1"):
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
os.environ["VLLM_USE_V1"] = "1"
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def create_vllm_config(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 1024,
block_size: int = 128,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="LLMDataDistCMgrConnector",
kv_role="kv_both",
kv_connector_module_path=
"vllm_ascend.distributed.llmdatadist_c_mgr_connector")
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"))
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float16,
False))
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_request(
request_id: int,
num_tokens: int = 10,
max_tokens: int = 128,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
) -> Request:
"""Make dummy request for testing."""
kv_transfer_params: Optional[dict[str, Any]] = None
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = dict(do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
if use_all_1s_for_prompt_tokens:
prompt_token_ids = [1] * num_tokens
else:
prompt_token_ids = [i * request_id for i in range(num_tokens)]
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
multi_modal_inputs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
**({
"pooling_params": []
} if not vllm_version_is("0.9.1") else {}),
eos_token_id=EOS_TOKEN_ID,
)
req.kv_transfer_params = kv_transfer_params
return req
def create_model_runner_output(
reqs: list[Request],
finished_sending: Optional[list[str]] = None,
finished_recving: Optional[list[str]] = None,
use_eos: bool = False,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""
# Make request data.
req_ids = [req.request_id for req in reqs]
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
# Make sampled tokens.
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}),
finished_sending=finished_sending,
finished_recving=finished_recving,
)

View File

@@ -252,7 +252,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = True,
@@ -262,8 +262,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads, head_size]
kv_cache: shape = [key_cache, value_cache]
key_cache = [num_blocks, block_size,
num_kv_heads, head_size]
value_cache = [num_blocks, block_size,
@@ -273,8 +272,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
use_kv_cache_int8 = kv_cache.numel(
) > 0 and kv_cache[0].dtype == torch.int8
use_kv_cache_int8 = len(
kv_cache) > 0 and kv_cache[0].dtype == torch.int8
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
@@ -314,7 +313,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if kv_cache.numel() > 0:
if len(kv_cache) > 1:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping

View File

@@ -62,7 +62,7 @@ class AscendAttentionTorchairBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads * head_size)
return (2, num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def get_bsh_kv_cache_shape(
@@ -71,7 +71,7 @@ class AscendAttentionTorchairBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads * head_size)
return (2, num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def swap_blocks(

View File

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -648,12 +649,13 @@ class AscendMLAImpl(MLAAttentionImpl):
def _compute_prefill_context(
self,
query: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
rope_dim: int,
attn_metadata: AscendMLAMetadata,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
):
assert len(kv_c_and_k_pe_cache) > 1
prefill_metadata = attn_metadata.prefill
if prefill_metadata is None or prefill_metadata.chunked_context is None:
return prefix_output, prefix_lse
@@ -663,21 +665,22 @@ class AscendMLAImpl(MLAAttentionImpl):
q_nope = query[..., :self.qk_nope_head_dim]
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([seq_len1, seq_len2])
kv_c_normed = torch.empty(toks,
kv_c_and_k_pe_cache.size(2),
num_heads,
latent_kv_dim,
dtype=query.dtype,
device=query.device)
k_pe = torch.empty(toks,
kv_c_and_k_pe_cache.size(2),
num_heads,
rope_dim,
dtype=query.dtype,
device=query.device)
@@ -727,10 +730,11 @@ class AscendMLAImpl(MLAAttentionImpl):
query: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert len(kv_c_and_k_pe_cache) > 1
num_tokens = query.size(0)
attn_output = torch.empty(num_tokens,
@@ -923,19 +927,13 @@ class AscendMLAImpl(MLAAttentionImpl):
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
q = torch.cat([q_nope, q_pe], dim=-1)
num_tokens = q.size(0)
attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device)
num_tokens = q_nope.size(0)
if self.running_in_graph:
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -994,16 +992,35 @@ class AscendMLAImpl(MLAAttentionImpl):
actual_seq_lengths_kv=decode_meta.seq_lens_list,
)
else:
torch_npu._npu_paged_attention_mla(
query=q,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.decode.block_table, # type:ignore
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
# The MLA_PA path will become the default path in the future; `_npu_paged_attention_mla` will
# be removed once the torch_npu release that contains `torch_npu.atb.npu_multi_head_latent_attention`
# becomes publicly available.
assert len(kv_c_and_k_pe_cache) > 1
if envs.VLLM_ASCEND_MLA_PA:
attn_output = torch_npu.atb.npu_multi_head_latent_attention(
q_nope, q_pe, kv_c_and_k_pe_cache[0],
kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, self.num_heads, self.scale,
self.num_kv_heads)
else:
q = torch.cat([q_nope, q_pe], dim=-1)
attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device)
k_cache = torch.cat(
[kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
torch_npu._npu_paged_attention_mla(
query=q,
key_cache=k_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.decode.
block_table, # type:ignore
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output,
@@ -1020,7 +1037,7 @@ class AscendMLAImpl(MLAAttentionImpl):
hidden_states_or_q_c: torch.Tensor, # query in unified attn
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
attn_metadata: M,
output: Optional[torch.Tensor] = None,
enable_multistream_mla: bool = False,
@@ -1151,8 +1168,12 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_q_pe.contiguous(),
prefill_k_pe,
max_seq_len=attn_metadata.prefill.max_seq_lens)
assert len(
kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
if self.torchair_graph_enabled:
if len(kv_cache) > 0 and kv_cache[0].numel(
if kv_cache[0].numel(
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
slots = attn_metadata.slot_mapping
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1162,16 +1183,15 @@ class AscendMLAImpl(MLAAttentionImpl):
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slots)
elif kv_cache.numel() > 0:
key = torch.cat([
kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
k_pe
],
dim=2)
torch_npu._npu_reshape_and_cache_siso(
key=key,
key_cache=kv_cache,
slot_indices=attn_metadata.slot_mapping.flatten())
else:
kv_c_normed = kv_c_normed.view(
[num_actual_toks, self.num_kv_heads, -1])
torch_npu._npu_reshape_and_cache(
key=kv_c_normed,
value=k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=attn_metadata.slot_mapping)
if has_prefill:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
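A small shape sketch of the split MLA cache used above; the 512/64 latent and rope dims are the usual DeepSeek values and are assumptions for illustration only:

```python
import torch

num_blocks, block_size, num_kv_heads = 8, 128, 1
kv_lora_rank, rope_dim = 512, 64  # assumed DeepSeek-style dims

nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim)
kv_c_and_k_pe_cache = (nope_cache, rope_cache)

# The fallback decode path above reconstructs the old single cache via torch.cat:
k_cache = torch.cat([kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
print(k_cache.shape)  # torch.Size([8, 128, 1, 576])
```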


@@ -23,7 +23,6 @@ from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
@@ -87,14 +86,11 @@ class AscendScheduler(Scheduler):
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
num_prealloc_computed_tokens = 0
# P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
num_prealloc_computed_tokens = (
request.num_computed_tokens)
else:
skip_cur_request()
continue
@@ -112,8 +108,8 @@ class AscendScheduler(Scheduler):
load_kv_async = False
# Get already-cached tokens.
if num_prealloc_computed_tokens == 0:
new_computed_blocks, num_native_computed_tokens = \
if request.num_computed_tokens == 0:
new_computed_blocks, num_new_local_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
@@ -121,18 +117,17 @@ class AscendScheduler(Scheduler):
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_native_computed_tokens))
request, num_new_local_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_native_computed_tokens +
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0
# Total computed tokens (allocated in prior step).
num_computed_tokens = num_prealloc_computed_tokens
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
@@ -142,9 +137,6 @@ class AscendScheduler(Scheduler):
# Number of tokens to be scheduled.
else:
prompt_limit = self._get_prompt_limit(request)
# Get already-cached tokens.
computed_blocks, num_computed_tokens = (
self.kv_cache_manager.get_computed_blocks(request))
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
@@ -172,7 +164,7 @@ class AscendScheduler(Scheduler):
skip_cur_request()
continue
assert num_new_tokens > 0
blocks = computed_blocks.blocks[0]
blocks = new_computed_blocks.blocks[0]
watermark = getattr(self.scheduler_config, "watermark", 0.01)
if not self._check_watermark_for_prefill(request, num_new_tokens,
@@ -184,8 +176,8 @@ class AscendScheduler(Scheduler):
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_native_computed_tokens,
new_computed_blocks=computed_blocks,
num_new_local_computed_tokens,
new_computed_blocks=new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async)
if new_blocks is None:
@@ -195,8 +187,7 @@ class AscendScheduler(Scheduler):
# KVConnector: update internal state after allocation.
# This information is used to determine if a load is
# needed for this request.
if num_external_computed_tokens:
assert self.connector is not None
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
@@ -210,6 +201,7 @@ class AscendScheduler(Scheduler):
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,


@@ -25,3 +25,8 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector(
"AscendSimpleConnector",
"vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector")
KVConnectorFactory.register_connector(
"LLMDataDistCMgrConnector",
"vllm_ascend.distributed.llmdatadist_c_mgr_connector",
"LLMDataDistCMgrConnector")


@@ -0,0 +1,883 @@
import contextlib
import json
import math
import os
import threading
import time
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Tuple
import llm_datadist # type: ignore
import msgspec
import torch
import zmq
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
LLMException, LLMRole)
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.forward_context import ForwardContext
from vllm.utils import get_ip, logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
from vllm_ascend import envs
from vllm_ascend.soc_info import NPUSocInfo
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
torch.float16: llm_datadist.DataType.DT_FLOAT16,
torch.bfloat16: llm_datadist.DataType.DT_BF16,
torch.float: llm_datadist.DataType.DT_FLOAT,
torch.float32: llm_datadist.DataType.DT_FLOAT,
torch.int8: llm_datadist.DataType.DT_INT8,
torch.int64: llm_datadist.DataType.DT_INT64,
torch.int32: llm_datadist.DataType.DT_INT32
}
class LLMDataDistCMgrEvent(Enum):
ReqForMetadata = 0
ReqForFinished = 1
class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
super_pod_id: str
server_id: str
device_id: str
device_ip: str
super_device_id: str
cluster_id: int
@dataclass
class ReqMeta:
local_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
remote_port: str
engine_id: str
remote_tp_size: str
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
def add_new_req(self, request_id: str, local_block_ids: list[int],
kv_transfer_params: dict[str, Any]):
self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_tp_size=kv_transfer_params["remote_tp_size"],
)
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: Optional[
LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
vllm_config, self.engine_id)
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(
self,
kv_caches: dict[
str, # type: ignore[override]
Tuple[torch.Tensor]]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished(finished_req_ids)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
LLMDataDistCMgrConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager."""
pass
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata, **kwargs) -> None:
"""LLMDataDistCMgrConnector does not save explicitly."""
pass
def wait_for_save(self):
"""LLMDataDistCMgrConnector does not save explicitly."""
pass
class LLMDataDistCMgrConnectorScheduler():
def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
self.local_ip = get_ip()
# Can not retrieve the parallel config since it is not initialized.
self.local_dp_rank = None
self.tp_size = None
dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}"
)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
assert num_computed_tokens % self.block_size == 0
# Note: we use the full prompt token count as the amount of data to transmit here.
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
return count, count > 0
# No remote prefill for this request.
return 0, False
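# Worked example (illustrative): with do_remote_prefill=True, a 35-token prompt and
# num_computed_tokens=0, this returns (35, True) -- all 35 prompt tokens are pulled
# from the remote prefill node asynchronously between scheduler steps.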
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
num_externel_tokens: int):
params = request.kv_transfer_params
logger.debug(
f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}"
)
if params is not None and params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port", "remote_tp_size")):
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
else:
logger.warning("" \
f"Invalid KVTransferParams {params}, This request will be discard")
else:
assert num_externel_tokens == 0
params["do_remote_prefill"] = False
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = LLMDataDistCMgrConnectorMetadata()
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params)
self._reqs_need_recv.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
params = request.kv_transfer_params
logger.debug(
"LLMDataDistCMgrConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
if (params is None or not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None
# note: NIXL transfers full blocks only, but there is no strong reason to do that here,
# so we just transfer whatever data the prefill node computed.
# note: there might be issues with this; check it if any unexpected results show up.
computed_block_ids = block_ids
delay_free_blocks = len(computed_block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.local_ip,
remote_port=self.port,
remote_tp_size=str(
self.vllm_config.parallel_config.tensor_parallel_size),
)
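# Example of the kv_transfer_params payload returned above for a finished prefill request
# (values are illustrative only); the decode side consumes exactly these keys in
# get_num_new_matched_tokens / update_state_after_alloc:
# {"do_remote_prefill": True, "do_remote_decode": False, "remote_block_ids": [3, 4, 5],
#  "remote_engine_id": "engine-0", "remote_host": "10.0.0.1", "remote_port": 5559,
#  "remote_tp_size": "8"}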
class LLMDataDistCMgrConnectorWorker():
"""
Implementation of Worker side methods
"""
def __init__(self, vllm_config: VllmConfig):
assert vllm_config.kv_transfer_config is not None
logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
# we assume the local node only contains dp and tp, and tp will not communicate inter-node.
# for any scenario beyond this scope, the functionality of this connector is not guaranteed.
self.local_rank_on_node = get_world_group().rank % (
vllm_config.parallel_config.data_parallel_size_local *
vllm_config.parallel_config.tensor_parallel_size)
self.local_rank = get_world_group().local_rank
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_rank = get_tp_group().rank_in_group
self.rank = get_world_group().rank
self.local_ip = get_ip()
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
self.local_agent_metadata: Optional[
LLMDataDistCMgrAgentMetadata] = None
self.vllm_config = vllm_config
self.executor = ThreadPoolExecutor(1)
self.thread_lock = threading.Lock()
self.llm_datadist_role = None
self.llm_datadist_remote_role = None
if self.kv_transfer_config.kv_role == "kv_producer":
self.llm_datadist_role = LLMRole.PROMPT
self.llm_datadist_remote_role = LLMRole.DECODER
elif self.kv_transfer_config.kv_role == "kv_consumer":
self.llm_datadist_role = LLMRole.DECODER
self.llm_datadist_remote_role = LLMRole.PROMPT
else:
raise RuntimeError(
f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}"
)
# linked_cluster records the clusters that have already established a connection; its format is {"cluster_id": "comm_name"}
self.linked_cluster: dict[Any, Any] = {}
self.prefill_device_list: list[tuple[int, int]] = []
self.decode_device_list: list[tuple[int, int]] = []
global_rank_table = self.read_offline_rank_table()
self.local_agent_metadata = self.read_agent_metadata(
global_rank_table, self.local_ip, self.local_rank_on_node,
self.llm_datadist_role)
self.llm_datadist = LLMDataDist(self.llm_datadist_role,
self.local_agent_metadata.cluster_id)
self.init_llm_datadist()
self.finished_reqs: set[str] = set()
self.soc_info = NPUSocInfo()
# Set hccl deterministic for model execute
os.environ["HCCL_DETERMINISTIC"] = "true"
self.done_receiving_counts: defaultdict[str,
set[int]] = defaultdict(set)
def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
port = envs.VLLM_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs.VLLM_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
url = f"tcp://0.0.0.0:{port}"
msg_encoder = msgspec.msgpack.Encoder()
msg_decoder = msgspec.msgpack.Decoder()
msg_to_send = msg_encoder.encode(self.local_agent_metadata)
logger.debug(f"Start to listen to address: {url}")
logger.debug(
f"The local agent metadata have {len(msg_to_send)} bytes here")
logger.info(
f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers"
)
with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined]
event.set()
while True:
identity, _, msg = sock.recv_multipart()
event_msg, decode_msg = msg_decoder.decode(msg)
event_msg = LLMDataDistCMgrEvent(event_msg)
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
if "cluster_id" in decode_msg:
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
logger.info(
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
)
sock.send_multipart((identity, b"", msg_to_send))
self.add_remote_agent(decode_msg)
else:
logger.warning(
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
)
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
finished_req_id = decode_msg[0]
decode_tp_rank = decode_msg[1]
decode_tp_size = decode_msg[2]
with self.thread_lock:
if self._increment_task_count(finished_req_id,
decode_tp_rank,
decode_tp_size):
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
self.finished_reqs.add(finished_req_id)
sock.send_multipart(
(identity, b"", b"receiving decode finished"))
else:
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
)
def _increment_task_count(self, request_id: str, tp_rank: int,
decode_tp_size: int):
if request_id not in self.done_receiving_counts:
self.done_receiving_counts[request_id] = set()
if tp_rank in self.done_receiving_counts[request_id]:
logger.warning(
f"Received duplicate done signal for request {request_id} "
f"from tp rank {tp_rank}. Ignoring.")
return False
self.done_receiving_counts[request_id].add(tp_rank)
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
self.done_receiving_counts.pop(request_id)
logger.info("All transfers completed for request: "
f"{request_id}. Total ranks: "
f"{decode_tp_size}.")
return True
return False
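# Illustrative flow: with decode_tp_size=4, the prefill side marks a request finished only
# after done signals from all four decode tp ranks {0, 1, 2, 3} have arrived; duplicate
# signals from the same rank are ignored.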
def init_llm_datadist(self):
assert self.local_agent_metadata is not None
llm_config = LLMConfig()
llm_config.device_id = self.local_rank
llm_config.sync_kv_timeout = 20000
llm_config.enable_switch_role = True
llm_config.enable_cache_manager = True
llm_config.enable_remote_cache_accessible = True
llm_config_options = llm_config.generate_options()
self.llm_datadist.init(llm_config_options)
self.cache_manager = self.llm_datadist.cache_manager
logger.info(
f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
)
def read_offline_rank_table(self):
assert (
envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
rank_table_path = envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
with open(rank_table_path, "r", encoding="utf-8") as f:
global_rank_table = json.load(f)
decode_device_list = global_rank_table["decode_device_list"]
for decode_device in decode_device_list:
server_id = decode_device["server_id"]
device_id = decode_device["device_id"]
self.decode_device_list.append((server_id, device_id))
prefill_device_list = global_rank_table["prefill_device_list"]
for prefill_device in prefill_device_list:
server_id = prefill_device["server_id"]
device_id = prefill_device["device_id"]
self.prefill_device_list.append((server_id, device_id))
# global_rank_table = json.dumps(global_rank_table)
return global_rank_table
def read_agent_metadata(self, global_rank_table, server_id, device_rank,
agent_role):
devices_type_list = []
agent_metadata = None
if self.llm_datadist_role == LLMRole.PROMPT:
devices_type_list.append("prefill_device_list")
elif self.llm_datadist_role == LLMRole.DECODER:
devices_type_list.append("decode_device_list")
else:
devices_type_list.append("prefill_device_list")
devices_type_list.append("decode_device_list")
for device_type in devices_type_list:
device_list = global_rank_table[device_type]
device_list = [
d for d in device_list if d.get("server_id") == server_id
]
if len(device_list) <= device_rank:
continue
device_info = device_list[device_rank]
super_pod_id_ = device_info.get("super_pod_id", None)
server_id_ = device_info["server_id"]
device_id_ = device_info["device_id"]
device_ip_ = device_info["device_ip"]
super_device_id_ = device_info.get("super_device_id", None)
cluster_id_ = int(device_info["cluster_id"])
agent_metadata = LLMDataDistCMgrAgentMetadata(
super_pod_id=super_pod_id_,
server_id=server_id_,
device_id=device_id_,
device_ip=device_ip_,
super_device_id=super_device_id_,
cluster_id=cluster_id_,
)
assert agent_metadata is not None, f"Can't read the target server_id {server_id} and device_rank {device_rank} from rank table"
return agent_metadata
def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
assert len(first_kv_cache_tuple) > 1
assert self.local_agent_metadata is not None
kv_cache_dtype = first_kv_cache.dtype
self.use_mla: bool = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
# MHA case. [2 (k and v), num_blocks, ...]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = math.prod(block_shape)
self.cache_addr: list[int] = []
alignment = 2 * 1024 * 1024
if self.use_mla:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
k_normed = None
k_pe = None
for cache_or_caches in kv_caches.values():
assert len(cache_or_caches) > 1
k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
cache_k_normed_addr_list.append(k_normed.data_ptr())
cache_k_pe_addr_list.append(k_pe.data_ptr())
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)
cache_desc_k_normed = CacheDesc(
len(self.cache_addr[0]), [*k_normed.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_pe = CacheDesc(
len(self.cache_addr[1]), [*k_pe.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=0)
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=1)
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
self.cache_key = (cache_key_k_normed, cache_key_k_pe)
try:
cache_k_normed = self.cache_manager.register_blocks_cache(
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
cache_k_pe = self.cache_manager.register_blocks_cache(
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
self.cache = (cache_k_normed, cache_k_pe)
logger.info("LLMDataDistWorker: End of register Paged Cache.")
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
else:
for cache_or_caches in kv_caches.values():
for cache in cache_or_caches:
base_addr = cache.data_ptr()
assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
self.cache_addr.append(base_addr)
# register paged kv cache into the llm_cache manager
self.cache_desc = CacheDesc(
len(self.cache_addr), [*cache.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
self.cache_key = BlocksCacheKey(
cluster_id=int(self.local_agent_metadata.cluster_id))
logger.info(
f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
)
try:
self.cache = self.cache_manager.register_blocks_cache(
self.cache_desc, self.cache_addr, self.cache_key)
logger.info(
"LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
self.ready_event = threading.Event()
self.metadata_agent_listener_t = threading.Thread(
target=self.listen_for_agent_metadata_req,
args=(self.ready_event, ),
daemon=True,
name="metadata_agent_listener")
self.metadata_agent_listener_t.start()
self.ready_event.wait()
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
futures = []
for req_id, meta in metadata.requests.items():
logger.debug(f"Start to transmit {req_id}")
future = self.executor.submit(
self._read_blocks,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_ip=meta.remote_host,
remote_port=int(meta.remote_port),
remote_engine_id=meta.engine_id,
request_id=req_id,
remote_tp_size=meta.remote_tp_size,
)
futures.append(future)
def handle_exception(future):
if future.exception():
logger.error(f"KV transfer task failed: {future.exception()}")
for future in futures:
future.add_done_callback(handle_exception)
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
assert self.local_agent_metadata is not None
remote_cluster_id = metadata.cluster_id
if remote_cluster_id in self.linked_cluster:
logger.debug(
f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
)
return remote_cluster_id
remote_super_pod_id = metadata.super_pod_id
remote_server_id = metadata.server_id
is_same_server = remote_server_id == self.local_agent_metadata.server_id
is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
if self.llm_datadist_role == LLMRole.PROMPT:
prefill_metadata = self.local_agent_metadata
decode_metadata = metadata
else:
prefill_metadata = metadata
decode_metadata = self.local_agent_metadata
comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
cluster_rank_info = {
prefill_metadata.cluster_id: 0,
decode_metadata.cluster_id: 1
}
rank_table = {}
rank_table["version"] = "1.2"
rank_table["server_count"] = "1" if is_same_server else "2"
rank_table["status"] = "completed"
# generate server_list for rank table
rank_table["server_list"] = [] # type: ignore[assignment]
decode_server_device_info = None
prefill_server_device_info = {
"device": [{
k: v
for k, v in [(
"device_id", prefill_metadata.device_id
), ("device_ip", prefill_metadata.device_ip
), ("super_device_id",
prefill_metadata.super_device_id), ("rank_id", "0")]
if v is not None
}],
"server_id":
prefill_metadata.server_id
}
if is_same_server:
prefill_server_device_info["device"].append( # type: ignore[attr-defined]
{
k: v
for k, v in [(
"device_id", decode_metadata.device_id
), ("device_ip", decode_metadata.device_ip
), ("super_device_id",
decode_metadata.super_device_id), ("rank_id", "1")]
if v is not None
})
else:
decode_server_device_info = {
"device": [{
k: v
for k, v in [(
"device_id", decode_metadata.device_id
), ("device_ip", decode_metadata.device_ip
), ("super_device_id",
decode_metadata.super_device_id), ("rank_id", "1")]
if v is not None
}],
"server_id":
decode_metadata.server_id
}
rank_table["server_list"].append( # type: ignore[attr-defined]
prefill_server_device_info)
if decode_server_device_info is not None:
rank_table["server_list"].append( # type: ignore[attr-defined]
decode_server_device_info)
if self.soc_info.is_a3:
# generate super_pod_list for rank table
super_pod_list = []
prefill_super_pod_info = {
"super_pod_id": prefill_metadata.super_pod_id,
"server_list": [{
"server_id": prefill_metadata.server_id
}],
}
if is_same_pod and not is_same_server:
prefill_super_pod_info[
"server_list"].append( # type: ignore[attr-defined]
{"server_id": decode_metadata.server_id})
super_pod_list.append(prefill_super_pod_info)
if not is_same_pod:
decode_super_pod_id = {
"super_pod_id": decode_metadata.super_pod_id,
"server_list": [{
"server_id": decode_metadata.server_id
}],
}
super_pod_list.append(decode_super_pod_id)
rank_table[
"super_pod_list"] = super_pod_list # type: ignore[assignment]
logger.info(
f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
)
logger.info(f"rank table \n{rank_table}")
logger.info(f"comm name: {comm_name}")
logger.info(f"cluster rank info: {cluster_rank_info}")
comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
json.dumps(rank_table))
while True:
ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
if ret == llm_datadist.RegisterMemStatus.OK:
logger.info(
f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
)
break
elif ret == llm_datadist.RegisterMemStatus.FAILED:
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
)
time.sleep(1)
logger.info("Checking query_register_mem_status again")
self.linked_cluster.update({remote_cluster_id: comm_id})
logger.info(f"cached linked cluster: {self.linked_cluster}")
logger.info(
f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
)
return remote_cluster_id
def remove_remote_agent(self, cluster_id: int):
if cluster_id not in self.linked_cluster:
logger.warning(
f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list"
)
comm_id = self.linked_cluster[cluster_id]
try:
self.llm_datadist.unlink(comm_id)
self.linked_cluster.pop(cluster_id)
except LLMException:
logger.error(
f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment"
)
logger.info(
f"Successfully remove remote client with cluster id {cluster_id} !"
)
def connect_to_remote_agent(self, host: str, port: int) -> int:
url = f"tcp://{host}:{port}"
logger.debug(f"Querying metadata from url: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode(
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
logger.info("Try request remote metadata from socket......")
sock.send(msg_send)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder()
metadata = decoder.decode(metadata_bytes)
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
logger.info(f"recving metadata: {metadata}")
cluster_id = self.add_remote_agent(metadata)
return cluster_id
def send_finish_to_remote(self, host: str, port: int, request_id):
url = f"tcp://{host}:{port}"
logger.debug(f"Sending finished to remote: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode([
LLMDataDistCMgrEvent.ReqForFinished,
[request_id, self.tp_rank, self.tp_size]
])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
try:
sock.send(msg_send)
logger.debug(
f"Request id {request_id} finished message send to remote {url}"
)
_ = sock.recv()
except Exception as e:
logger.error(
f"Failed to send reqest_id {request_id} to prefill: {e}")
def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_ip: str,
remote_port: int,
remote_engine_id: str,
request_id: str,
remote_tp_size: str,
):
# if remote_ip not in self.linked_cluster:
tp_offset = self.tp_rank % int(remote_tp_size)
remote_cluster_id = self.connect_to_remote_agent(
remote_ip, remote_port + tp_offset)
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
logger.info(f"remote cluster id is: {remote_cluster_id}")
if self.use_mla:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key,
self.cache, # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
self.send_finish_to_remote(remote_ip, remote_port, request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requuests."""
import copy
with self.thread_lock:
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
self.finished_reqs.clear()
if self.llm_datadist_role == LLMRole.PROMPT:
return req_ids_to_ret, None
else:
return None, req_ids_to_ret
# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
ctx: Optional[zmq.Context] = None # type: ignore[name-defined]
try:
ctx = zmq.Context() # type: ignore[attr-defined]
if socket_type == zmq.ROUTER: # type: ignore[attr-defined]
socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined]
socket.bind(addr)
elif socket_type == zmq.REQ: # type: ignore[attr-defined]
socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined]
socket.connect(addr)
else:
raise ValueError(f"Unexpected socket type: {socket_type}")
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)


@@ -133,6 +133,28 @@ env_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
# `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
# used by llmdatadist to build the communication topology for kv cache transfer; it is
# required whenever `LLMDataDistCMgrConnector` is used as the kv connector for disaggregated
# PD. The rank table can be generated with the script `gen_ranktable.sh`
# in vllm_ascend's examples folder.
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH":
lambda: os.getenv("DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None),
# `LLMDataDistCMgrConnector` required variable. `VLLM_ASCEND_LLMDD_RPC_IP` is used as the
# rpc communication listening ip, which will be used to receive the agent metadata from the
# remote worker.
"VLLM_ASCEND_LLMDD_RPC_IP":
lambda: os.getenv("VLLM_ASCEND_LLMDD_RPC_IP", "0.0.0.0"),
# `LLMDataDistCMgrConnector` required variable. `VLLM_LLMDD_RPC_PORT` is used as the
# rpc communication listening port, which will be used to receive the agent metadata from the
# remote worker.
"VLLM_LLMDD_RPC_PORT":
lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)),
# Whether to enable mla_pa for deepseek mla decode. This flag will be removed once the torch_npu
# that supports it is publicly available and mla_pa becomes the default deepseek decode path.
"VLLM_ASCEND_MLA_PA":
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
}
# end-env-vars-definition
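As a hedged illustration of how these variables might be set before launching a prefill or decode instance (the path and port values below are placeholders, not part of this change):

```python
import os

# Placeholder values -- point the path at the rank table produced by gen_ranktable.sh
# and pick a free base port for the metadata RPC listener.
os.environ["DISAGGREGATED_PREFILL_RANK_TABLE_PATH"] = "/vllm-workspace/ranktable.json"
os.environ["VLLM_ASCEND_LLMDD_RPC_IP"] = "0.0.0.0"  # metadata listener bind address
os.environ["VLLM_LLMDD_RPC_PORT"] = "5557"          # base port, offset per dp/tp rank
os.environ["VLLM_ASCEND_MLA_PA"] = "0"              # keep the default MLA decode path
```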


@@ -32,7 +32,8 @@ import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
@@ -363,6 +364,10 @@ class CustomDeepseekV2MoE(nn.Module):
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.kv_consumer = None
transfer_config = get_current_vllm_config().kv_transfer_config
if transfer_config is not None:
self.kv_consumer = transfer_config.kv_role == "kv_consumer"
self.params_dtype = torch.get_default_dtype()
self.rm_router_logits = self.experts.rm_router_logits
@@ -386,6 +391,11 @@ class CustomDeepseekV2MoE(nn.Module):
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
# If this node is a kv_consumer, we force MoE to always run in the decode path so that
# the behaviour of dummy_run and normal model execution stays aligned.
if self.kv_consumer:
is_prefill = False
enable_force_load_balance = False
# router_logits: (num_tokens, n_experts)
router_logits = None


@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from typing import List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
@@ -37,7 +37,7 @@ def vanilla_chunked_prefill(
scale: float,
alibi_slopes: Optional[torch.Tensor],
causal: bool = True,
) -> None:
) -> torch.Tensor:
num_query_heads = query.shape[1]
head_dim = value_cache.shape[3]
num_kv_heads = value_cache.shape[2]
@@ -138,7 +138,8 @@ def vanilla_chunked_prefill(
def vanilla_chunked_prefill_mla(
output: torch.Tensor, # (num_tokens, num_heads, v_head_dim)
query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim)
kv_cache: torch.Tensor, # (num_blocks, block_size, latent_kv)
kv_cache: Tuple[
torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv)
block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq)
query_lens: torch.Tensor, # (batch_size)
context_lens: torch.Tensor, # (batch_size)
@@ -152,22 +153,25 @@ def vanilla_chunked_prefill_mla(
alibi_slopes: Optional[torch.Tensor],
causal: bool = True) -> None:
batch_size = block_tables.size(0)
assert len(kv_cache) > 1
assert query_lens.size(0) == batch_size
num_heads = query.size(1)
block_size = kv_cache.size(1)
latent_kv_dim = kv_cache.size(3) - rope_dim
nope_cache = kv_cache[0]
rope_cache = kv_cache[1]
block_size = nope_cache.size(1)
latent_kv_dim = nope_cache.size(-1)
max_num_blocks_per_seq = block_tables.size(1)
batch_size = query_lens.size(0)
kv_cache = kv_cache.squeeze()
# select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
cache_kv_c_pe = kv_cache[block_tables].view(
batch_size, max_num_blocks_per_seq * block_size,
latent_kv_dim + rope_dim)[:, :max_context_len, :]
# get kv_c and k_pe
nope_cache = nope_cache.squeeze()
# select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe
# cached_kv_c: [batch_size, max_context_len, latent_kv]
# cached_k_pe: [batch_size, max_context_len, rope_dim]
cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
cache_kv_c = nope_cache[block_tables].view(
batch_size, max_num_blocks_per_seq * block_size,
latent_kv_dim)[:, :max_context_len, :]
cache_k_pe = rope_cache[block_tables].view(
batch_size, max_num_blocks_per_seq * block_size,
rope_dim)[:, :max_context_len, :]
# get k_rope and v
# k_nope: [batch_size, max_context_len, num_heads, nope_dim]
# value: [batch_size, max_context_len, num_heads, v_head_dim]
@@ -258,8 +262,8 @@ def vanilla_chunked_prefill_mla(
attn_output = (attn_output[q_mask].view([-1, num_heads,
v_head_dim]).to(output.dtype))
output = output.view([-1, num_heads, v_head_dim])
output.copy_(attn_output[:query.size(0) - num_add_query])
attn_output = attn_output.view_as(output)
output.copy_(attn_output)
return attn_output


@@ -122,7 +122,10 @@ def fused_experts_with_mc2(
if log2phy is not None:
topk_ids = log2phy[topk_ids]
global_bs = 0
moe_expert_num = len(expert_map) + global_redundant_expert_num
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
moe_expert_num = global_redundant_expert_num
# hidden_states = hidden_states.bfloat16()
kwargs_mc2 = {
"x": hidden_states,

vllm_ascend/soc_info.py (new file)

@@ -0,0 +1,14 @@
from dataclasses import dataclass
import torch_npu
@dataclass
class NPUSocInfo:
is_a3: bool = False
def __post_init__(self):
torch_npu.npu._lazy_init()
self.soc_version = torch_npu._C._npu_get_soc_version()
if self.soc_version in (250, 251, 252, 253, 254, 255):
self.is_a3 = True
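A hedged usage sketch of the new helper (it requires an Ascend environment with torch_npu installed); the connector worker consults is_a3 to decide whether to add a super_pod_list section to the per-link rank table:

```python
from vllm_ascend.soc_info import NPUSocInfo

soc = NPUSocInfo()
if soc.is_a3:
    print("A3 SoC: super_pod_list will be included in the link rank table")
else:
    print("non-A3 SoC: no super_pod_list needed")
```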


@@ -17,7 +17,9 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
import copy
import gc
import math
import os
import time
import types
@@ -37,9 +39,12 @@ from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -342,6 +347,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
# kv role
self.is_kv_producer = False
if vllm_config.kv_transfer_config is not None:
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -908,7 +918,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
Optional[set[str]], Optional[set[str]]]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -1144,6 +1155,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config,
num_tokens=num_input_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {}
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.kv_caches
@@ -1174,6 +1186,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
**model_kwargs,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
scheduler_output)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@@ -1203,7 +1218,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens)
num_scheduled_tokens, finished_sending, finished_recving)
def _get_cumsum_and_arange(
self,
@@ -1436,12 +1451,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
"prepare input and forward"):
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if not has_kv_transfer_group():
logger.debug(
"skip this step for we receive the data from remote disaggregate prefill node"
)
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens_np) = (self._process_reqs(
scheduler_output, intermediate_tensors))
num_scheduled_tokens_np, finished_sending,
finished_recving) = (self._process_reqs(scheduler_output,
intermediate_tensors))
with ProfileExecuteDuration().capture_async("post process"):
# Broadcast PP output for external_launcher (torchrun)
@@ -1593,6 +1614,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states,
)
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
@@ -1601,6 +1625,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
)
durations = ProfileExecuteDuration().pop_captured_sync()
@@ -1615,6 +1641,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return model_runner_output
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfer(scheduler_output))
# For the case of no forward caused by receiving remote kv,
# one round of dummy inference is necessary
# to prevent hangs on the collective calls.
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update the KVConnector with the KVConnector metadata before forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfer(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
@torch.inference_mode()
def _dummy_run(
self,
@@ -1633,6 +1702,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
# Force a prefill-stage dummy run when this node is acting as a kv producer.
if self.is_kv_producer:
with_prefill = True
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
@@ -1899,9 +1972,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config = kv_cache_config
import torch_npu
acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
) else ACL_FORMAT_FRACTAL_ND
) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
kv_caches: Dict[str, torch.Tensor] = {}
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
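# Worked example (illustrative): with alignment = 2 MiB (2_097_152 bytes) and a bf16 tensor
# whose data_ptr() is 2 MiB + 2, aligned_addr rounds up to 4 MiB, so
# offset = (4 MiB - (2 MiB + 2)) / 2 = 1_048_575 elements are skipped and the returned view
# starts exactly on a 2 MiB boundary.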
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
@@ -1935,6 +2014,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
alignment = 2 * 1024 * 1024
# TODO: remove this after the OOM issue is located and fixed; otherwise some models may
# encounter OOM issues.
if isinstance(kv_cache_spec, FullAttentionSpec):
@@ -1949,58 +2029,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
if self.torchair_graph_enabled:
if len(kv_cache_shape) == 3:
# For non-MLA attention backends that use torchair, we pass a kv_cache layout of
# BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
dtype = kv_cache_spec.dtype
if self.model_config.is_deepseek_mla:
kv_caches[layer_name] = (
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device),
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
# atb reshape_and_cache does not support torchair.
kv_caches[layer_name] = (
torch_npu.npu_format_cast(
kv_caches[layer_name][0],
ACL_FORMAT_FRACTAL_ND),
torch_npu.npu_format_cast(
kv_caches[layer_name][1],
ACL_FORMAT_FRACTAL_ND),
)
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
nope_dim = head_size - rope_dim
nope_cache_shape = (num_blocks, block_size,
num_kv_heads, nope_dim)
rope_cache_shape = (num_blocks, block_size,
num_kv_heads, rope_dim)
if self.vllm_config.kv_transfer_config is None:
# For the non-disaggregated PD scenario, allocate the kv cache in the normal way.
rope_cache = torch.zeros(rope_cache_shape,
dtype=dtype,
device=self.device)
nope_cache = torch.zeros(nope_cache_shape,
dtype=dtype,
device=self.device)
rope_cache = torch_npu.npu_format_cast(
rope_cache, acl_format)
nope_cache = torch_npu.npu_format_cast(
nope_cache, acl_format)
else:
# for MLA attention backend that use torchair.
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank,
),
dtype=self.dtype,
pin_memory=True,
# In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
# address should be aligned to 2M. In most cases torch_npu can allocate 2M-aligned memory, but
# we found some exceptions during testing, so we manually align the memory here; this part
# of the code may consume an extra 2M * 2 * elem_size of memory per layer.
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
nope_allocate_shape_alignment = nope_allocate_shape + alignment
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
rope_allocate_shape_alignment = rope_allocate_shape + alignment
nope_cache = torch.zeros(
nope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
layer_kv_cache_pe = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.
qk_rope_head_dim, ),
dtype=self.dtype,
pin_memory=True,
rope_cache = torch.zeros(
rope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
kv_caches[layer_name] = (layer_kv_cache_nope,
layer_kv_cache_pe)
kv_caches[layer_name] = (
torch_npu.npu_format_cast(
kv_caches[layer_name][0], acl_format),
torch_npu.npu_format_cast(
kv_caches[layer_name][1], acl_format),
)
nope_cache = align_memory(
nope_cache,
alignment)[:nope_allocate_shape].view(
nope_cache_shape)
rope_cache = align_memory(
rope_cache,
alignment)[:rope_allocate_shape].view(
rope_cache_shape)
kv_caches[layer_name] = (nope_cache, rope_cache)
else:
kv_caches[layer_name] = torch.zeros(
kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device)
kv_caches[layer_name] = \
torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
num_caches = kv_cache_shape[0]
kv_cache_list = []
for i in range(num_caches):
cache_shape = kv_cache_shape[1:]
if self.vllm_config.kv_transfer_config is None:
kv_cache = torch.zeros(cache_shape,
dtype=dtype,
device=self.device)
kv_cache = torch_npu.npu_format_cast(
kv_cache, acl_format)
else:
cache_size = math.prod(cache_shape)
cache_size_aligned = cache_size + alignment
kv_cache = torch.zeros(cache_size_aligned,
dtype=dtype,
device=self.device)
kv_cache = align_memory(
kv_cache,
alignment)[:cache_size].view(cache_shape)
kv_cache_list.append(kv_cache)
kv_caches[layer_name] = tuple(kv_cache_list)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
@@ -2011,6 +2111,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each


@@ -78,6 +78,9 @@ class NPUWorker(WorkerBase):
is_driver_worker=is_driver_worker)
# Try to import mindie_turbo to accelerate vLLM inference.
local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
world_size = self.vllm_config.parallel_config.world_size
self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
try_register_lib(
"mindie_turbo",
"MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."