[Disaggregated Prefill] P2P Disaggregated Prefill based on llm_datadist (#694)

### What this PR does / why we need it?
- This PR proposes a P2P version of Disaggregated Prefill based on
llm_datadist which manages data transfer.

- This solution reconstructs previous offline single-node Disaggregated
Prefill solution, and supports multi-node and online serveing now.

- Currently this solution supports 1P1D situation of Deepseek hybrid
parallelism (P: TP+EP, D: DP+EP). Note that xPyD situation is considered
in the solution design, and will be supported soon within v1 engine.

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
Co-authored-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
whx
2025-05-01 22:31:36 +08:00
committed by GitHub
parent 84e2ed898b
commit 8b194ad12e
18 changed files with 1769 additions and 32 deletions

View File

@@ -136,18 +136,9 @@ jobs:
id: filter_spec_decode
uses: dorny/paths-filter@v3
with:
# speculative decode seems will cause oom issue, disable it now on ci test
filters: |
speculative_tests_changed:
- "tests/singlecard/spec_decode/**"
- "tests/multicard/spec_decode_e2e/**"
- "vllm_ascend/worker/worker.py"
- "vllm_ascend/worker/model_runner.py"
- "vllm_ascend/worker/multi_step_runner.py"
- "vllm_ascend/worker/multi_step_worker.py"
- "vllm_ascend/worker/draft_model_runner.py"
- "vllm_ascend/patch/worker/patch_common/patch_metrics.py"
- "vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py"
- "vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py"
speculative_tests_changed: 'false'
- name: Run vllm-project/vllm-ascend Speculative Decode test
if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true' || github.event_name == 'schedule'

View File

@@ -2,12 +2,22 @@
This file demonstrates the example usage of disaggregated prefilling
We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode),
and then transfer the KV cache between them.
prompy_device_ips denotes device ip of NPU 0,1
decode_device_ips denotes device ip of NPU 2,3
The device ips of all NPUs in current server can be found through
examples/disaggregated_prefill/find_device_ips.py
"""
import multiprocessing as mp
import os
import time
from multiprocessing import Event, Process
kv_connector_extra_config = {
"prompt_device_ips": ["1.2.3.1", "1.2.3.2"],
"decode_device_ips": ["1.2.3.9", "1.2.3.10"],
"llmdatadist_comm_port": 26000,
}
def clean_up():
import gc
@@ -34,11 +44,10 @@ def run_prefill(prefill_done, process_close):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
'{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
)
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU.
global kv_connector_extra_config
ktc.kv_connector_extra_config = kv_connector_extra_config
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc,
max_model_len=2000,
@@ -69,15 +78,16 @@ def run_decode(prefill_done):
from vllm.config import KVTransferConfig
prompts = [
"Hello, how are you today?", "Hi, what is your name?",
"Tell me a very long story.", "what is your favourite book?"
"Hello, how are you today?",
"Hi, what is your name?",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
'{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
)
global kv_connector_extra_config
ktc.kv_connector_extra_config = kv_connector_extra_config
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc,
max_model_len=2000,

View File

@@ -0,0 +1,463 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import copy
import logging
import os
import threading
import time
import uuid
import aiohttp
import msgpack # type: ignore
import zmq
from quart import Quart, make_response, request
DP_PROXY_HTTP_PORT = 10004
DP_PROXY_ZMQ_REG_PORT = 30006
DP_PROXY_ZMQ_NOTIFY_PORT = 30005
PD_PROXY_ADDRESS = "127.0.0.1:30002"
MY_HTTP_ADDRESS = f"127.0.0.1:{DP_PROXY_HTTP_PORT}"
MY_ZMQ_ADDRESS_PLACEHOLDER = f"127.0.0.1:{DP_PROXY_ZMQ_REG_PORT}"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
TIME_INTERVAL_FOR_IDLE_RUN = 5e-4
DP_SIZE = 2
dp_instances: dict[str, bool] = {}
dp_cv = threading.Condition()
round_robin_index = 0
_idle_send_loop = None
def make_idle_request():
# Same as before
data = {
"prompt": "hi",
"max_tokens": 1,
"temperature": 0,
}
return data
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def send_idle_token_to_client(schedule_dict):
for key, value in schedule_dict.items():
if value:
continue
request_received_id = random_uuid()
idle_request_data = make_idle_request()
forward_request_id = f"dp_idle_{key}_{request_received_id}"
target_url = f'http://{key}/v1/completions'
logger.debug(
f"DP Decode Proxy: Sending idle token to D node {key} at {target_url}"
)
generator = forward_request_internal(target_url, idle_request_data,
forward_request_id)
try:
async for response in generator:
logger.debug(
f"DP Decode Proxy: Idle Request {request_received_id}: response from {key}, got response: {response}"
)
except Exception as e:
logger.warning(
f"DP Decode Proxy: Error sending idle token to {key}: {e}")
def metadata_collect_trigger(poller, router_socket):
global dp_instances
global dp_cv
global _idle_send_loop
with dp_cv:
dp_cv.wait()
while True:
try:
schedule_dict = copy.deepcopy(dp_instances)
for key in schedule_dict.keys():
schedule_dict[key] = False
first_start = False
start_time = None
while not all(schedule_dict.values()):
if start_time is not None:
time_interval = time.time() - start_time
logger.debug("check time interval: ", time_interval)
if time_interval > TIME_INTERVAL_FOR_IDLE_RUN:
logger.debug(
"exceeds max time interval send idle token to client"
)
# Send idle token to client in case of single dp rank run solo and block on the CCL part
asyncio.run_coroutine_threadsafe(
send_idle_token_to_client(schedule_dict),
_idle_send_loop) # type: ignore
# Note: Reset start time prevent consistently send idle token to client
# We only reset start time here, for some of the client may loss the idle token send from this proxy
# and we only exit this while loop when we make sure all the client are exactly start inference in this
# step
start_time = time.time()
socks = dict(poller.poll(timeout=500)) # timeout in 500ms
if socks:
logger.debug("receive socks from moniter threads: ", socks)
if router_socket in socks:
messages = router_socket.recv_multipart()
try:
# {"info": "notify_step", "http_address": ""}
for message in messages:
data = msgpack.loads(message)
http_addr = None
logger.debug(f"receive message {data}")
if data.get("info") == "notify_step":
http_addr = data.get("http_address")
if http_addr in schedule_dict.keys():
schedule_dict[http_addr] = True
logger.debug("set first time")
if not first_start:
logger.debug("record start time")
first_start = True
start_time = time.time()
else:
logger.warning("Unrecognize http address")
else:
logger.warning(
"Got unrecognize info type! We only accept notify step info yet"
)
except (msgpack.UnpackException, TypeError, KeyError) as e:
logger.error(
f"Error processing message from {http_addr}: {e}. Message: {data}"
)
except zmq.ZMQError as e: # type: ignore
logger.error(f"ZMQ Error in monitor thread: {e}")
if e.errno == zmq.ETERM: # type: ignore
logger.error(
"Monitor thread terminating due to context termination.")
break
time.sleep(1)
except Exception as e:
logger.error(f"Unexpected error in monitor thread: {e}")
import traceback
traceback.print_exc()
time.sleep(1)
def _listen_for_d_register(poller, router_socket):
global dp_instances
global dp_cv
global DP_SIZE
logger.info(
f"DP Decode Proxy: D Node ZMQ Listener started on ROUTER port {DP_PROXY_ZMQ_REG_PORT}"
)
while True:
try:
socks = dict(poller.poll(timeout=1000))
if router_socket in socks:
remote_id, message = router_socket.recv_multipart()
try:
data = msgpack.loads(message)
if data.get("type") == "DP":
http_addr = data.get("http_address")
zmq_addr = data.get("zmq_address")
if http_addr:
with dp_cv:
if http_addr not in dp_instances:
logger.info(
f"DP Decode Proxy: Registering D Node instance: http={http_addr}, zmq={zmq_addr}"
)
dp_instances[http_addr] = True
if len(dp_instances) >= DP_SIZE:
logger.info(
f"DP Decode Proxy: Reached expected D Node count ({DP_SIZE}). Notifying metadata collector."
)
dp_cv.notify_all()
else:
pass
else:
logger.warning(
f"DP Decode Proxy: Received D Node registration from {remote_id.decode()} without http_address. Data: {data}"
)
else:
logger.warning(
f"DP Decode Proxy: Received message with unexpected type from {remote_id.decode()}. Type: {data.get('type')}, Data: {data}"
)
except (msgpack.UnpackException, TypeError, KeyError) as e:
logger.error(
f"DP Decode Proxy: Error processing D Node registration from {remote_id.decode()}: {e}. Message: {message}"
)
except Exception as e:
logger.error(
f"DP Decode Proxy: Unexpected error processing D Node registration from {remote_id.decode()}: {e}"
)
except zmq.ZMQError as e: # type: ignore
logger.error(
f"DP Decode Proxy: ZMQ Error in D Node listener thread: {e}")
if e.errno == zmq.ETERM: # type: ignore
logger.info(
"DP Decode Proxy: D Node Listener thread terminating.")
break
time.sleep(1)
except Exception as e:
logger.error(
f"DP Decode Proxy: Unexpected error in D Node listener thread: {e}"
)
import traceback
traceback.print_exc()
time.sleep(1)
def _register_to_pd_proxy(pd_proxy_zmq_addr, my_http_addr, my_zmq_addr):
context = None
sock = None
while True:
try:
if context is None:
context = zmq.Context() # type: ignore
if sock is None:
sock = context.socket(zmq.DEALER) # type: ignore
identity = f"dp_proxy_{my_http_addr}".encode('utf-8')
sock.setsockopt(zmq.IDENTITY, identity) # type: ignore
sock.setsockopt(zmq.LINGER, 0) # type: ignore
logger.info(
f"DP Decode Proxy: Attempting to connect to PD Proxy at {pd_proxy_zmq_addr}..."
)
sock.connect(f"tcp://{pd_proxy_zmq_addr}")
logger.info(
f"DP Decode Proxy: Connected to PD Proxy at {pd_proxy_zmq_addr}."
)
data = {
"type": "D",
"http_address": my_http_addr,
"zmq_address": my_zmq_addr
}
logger.debug(
f"DP Decode Proxy: Sending registration/heartbeat to PD Proxy: {data}"
)
sock.send(msgpack.dumps(data))
time.sleep(5)
except zmq.ZMQError as e: # type: ignore
logger.error(
f"DP Decode Proxy: ZMQ Error connecting/sending to PD Proxy ({pd_proxy_zmq_addr}): {e}"
)
if sock:
sock.close()
sock = None
time.sleep(10)
except Exception as e:
logger.error(
f"DP Decode Proxy: Unexpected error in PD Proxy registration thread: {e}"
)
import traceback
traceback.print_exc()
if sock:
sock.close()
sock = None
time.sleep(10)
finally:
pass
def start_zmq_thread(hostname, port, socket_type, target_func, thread_name):
"""Generic ZMQ thread starter for ROUTER or PULL."""
if not hostname:
hostname = "0.0.0.0"
context = zmq.Context.instance() # type: ignore
socket = context.socket(socket_type)
socket.setsockopt(zmq.LINGER, 0) # type: ignore
try:
socket.bind(f"tcp://{hostname}:{port}")
except zmq.ZMQError as e: # type: ignore
logger.error(
f"DP Decode Proxy: Error binding ZMQ {socket_type} socket to tcp://{hostname}:{port}: {e}"
)
socket.close()
raise
poller = zmq.Poller() # type: ignore
poller.register(socket, zmq.POLLIN) # type: ignore
thread = threading.Thread(target=target_func,
args=(poller, socket),
daemon=True,
name=thread_name)
thread.start()
return thread, socket
def start_thread_with_event_loop():
global _idle_send_loop
asyncio.set_event_loop(_idle_send_loop)
_idle_send_loop.run_forever() # type: ignore
async def forward_request_internal(url, data, request_id):
try:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization":
f"Bearer {os.environ.get('OPENAI_API_KEY', '')}",
"X-Request-Id": request_id,
"Content-Type": "application/json"
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(
1024):
yield chunk_bytes
else:
error_content = await response.read()
logger.warning(
f"DP Decode Proxy: Error from D node {url} (status {response.status}): {error_content.decode(errors='ignore')}"
)
yield error_content
except aiohttp.ClientError as e:
logger.warning(
f"DP Decode Proxy: Error forwarding request {request_id} to D node {url}: {e}"
)
error_msg = f"Failed to connect or communicate with D node at {url}: {e}".encode(
'utf-8')
yield error_msg
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
@app.route('/v1/completions', methods=['POST'])
async def handle_request():
global dp_instances
global dp_cv
global round_robin_index
request_received_id = request.headers.get("X-Request-Id")
if not request_received_id:
fallback_id = f"dp_fallback_{random_uuid()}"
logger.warning(
f"DP Decode Proxy: Received request without X-Request-Id header. Using fallback ID: {fallback_id}"
)
request_received_id = fallback_id
else:
logger.info(
f"DP Decode Proxy: Received request from PD Proxy, using propagated ID: {request_received_id}"
)
try:
original_request_data = await request.get_json()
if not original_request_data:
return await make_response("Request body must be valid JSON.", 400)
target_addr = None
with dp_cv:
if not dp_instances:
logger.warning(
f"DP Decode Proxy: Request {request_received_id}: No D Node instances available/registered."
)
return await make_response("No Decode instances available.",
503)
dp_addresses = list(dp_instances.keys())
if not dp_addresses:
logger.error(
f"DP Decode Proxy: Request {request_received_id}: Internal error - dp_instances populated but list is empty."
)
return await make_response("Internal Server Error", 500)
current_selection_index = round_robin_index % len(dp_addresses)
target_addr = dp_addresses[current_selection_index]
round_robin_index += 1
logger.info(
f"DP Decode Proxy: Request {request_received_id}: Routing Decode to D Node {target_addr} (Index {current_selection_index})"
)
target_url = f'http://{target_addr}/v1/completions'
generator = forward_request_internal(target_url, original_request_data,
request_received_id)
response = await make_response(generator)
response.timeout = None
if original_request_data.get("stream", False):
response.headers['Content-Type'] = 'text/event-stream'
response.headers['Cache-Control'] = 'no-cache'
else:
response.headers['Content-Type'] = 'application/json'
logger.debug(
f"DP Decode Proxy: Request {request_received_id}: Streaming response from D node {target_addr}"
)
return response
except Exception as e:
logger.error(
f"DP Decode Proxy: Error handling request {request_received_id}: {e}"
)
return await make_response("Internal Server Error", 500)
if __name__ == '__main__':
d_listener_thread, d_reg_socket = start_zmq_thread(
"0.0.0.0",
DP_PROXY_ZMQ_REG_PORT,
zmq.ROUTER, # type: ignore
_listen_for_d_register, # type: ignore
"DP_DNodeListenerThread")
metadata_thread, notify_socket = start_zmq_thread(
"0.0.0.0",
DP_PROXY_ZMQ_NOTIFY_PORT,
zmq.PULL, # type: ignore
metadata_collect_trigger,
"DP_MetadataMonitorThread")
_idle_send_loop = asyncio.new_event_loop()
idle_loop_thread = threading.Thread(target=start_thread_with_event_loop,
daemon=True,
name="DP_IdleSendLoopThread")
idle_loop_thread.start()
pd_register_thread = threading.Thread(target=_register_to_pd_proxy,
args=(PD_PROXY_ADDRESS,
MY_HTTP_ADDRESS,
MY_ZMQ_ADDRESS_PLACEHOLDER),
daemon=True,
name="DP_PDRegisterThread")
pd_register_thread.start()
logger.info(
f"DP Decode Proxy: Starting Quart web server on http://0.0.0.0:{DP_PROXY_HTTP_PORT}"
)
zmq_context = zmq.Context.instance() # type: ignore
try:
app.run(host='0.0.0.0', port=DP_PROXY_HTTP_PORT)
except KeyboardInterrupt:
logger.info("DP Decode Proxy: KeyboardInterrupt received, stopping...")
except Exception as e:
logger.error(f"DP Decode Proxy: Failed to run Quart server: {e}")
finally:
logger.info("DP Decode Proxy: Shutting down...")
if _idle_send_loop and _idle_send_loop.is_running():
logger.info("DP Decode Proxy: Stopping idle send loop...")
_idle_send_loop.call_soon_threadsafe(_idle_send_loop.stop)
if d_reg_socket:
d_reg_socket.close()
if notify_socket:
notify_socket.close()
if zmq_context:
zmq_context.term()
logger.info("DP Decode Proxy: Shutdown complete.")

View File

@@ -0,0 +1,67 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
#
"""
This file provides a function to obtain ips of all NPU Devices in current machine.
"""
import os
import re
import subprocess
import vllm_ascend.envs as envs
# Get all device ips using hccn_tool
HCCN_TOOL_PATH = envs.HCCN_PATH
def get_device_ips(world_size: int):
npu_info = subprocess.run(
["npu-smi", "info", "-m"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
)
if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")
npu_start_idx = int(
re.match(r".*\n\t([0-9]+).*",
npu_info.stdout).group(1)) # type: ignore
device_ip_list = []
for ip_offset in range(world_size):
cmd = [
HCCN_TOOL_PATH,
"-i",
f"{npu_start_idx + ip_offset}",
"-ip",
"-g",
]
device_ip_info = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
)
device_ip = re.match(r"ipaddr:(.*)\n",
device_ip_info.stdout).group(1) # type: ignore
device_ip_list.append(device_ip)
return device_ip_list
# Pass number of NPUs into this function.
print(get_device_ips(8))

View File

@@ -0,0 +1,186 @@
import os
import socket
import threading
import uuid
import aiohttp
import msgpack # type: ignore
import zmq
from quart import Quart, make_response, request
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[
data["http_address"]] = data["zmq_address"]
print(
"Get a prefill register with http_addr %s and zmq_addr %s",
data["http_address"],
data["zmq_address"],
)
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[
data["http_address"]] = data["zmq_address"]
print(
"Get a decode register with http_addr %s and zmq_addr %s",
data["http_address"],
data["zmq_address"],
)
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context() # type: ignore
router_socket = context.socket(zmq.ROUTER) # type: ignore
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller() # type: ignore
poller.register(router_socket, zmq.POLLIN) # type: ignore
_listener_thread = threading.Thread(target=_listen_for_register,
args=[poller, router_socket],
daemon=True)
_listener_thread.start()
return _listener_thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global prefill_instances
global prefill_cv
with prefill_cv:
if len(prefill_instances) > 1:
print(
"Found more than 1 Prefill instances. Currently we only support 1P1D, so only"
f"the first Prefill instance({list(prefill_instances.keys())[0]}) will be used!"
)
if len(prefill_instances) == 0:
res_str = (
"No Prefill instances has been registered to proxy. Please confirm that you have successfully"
" and correctly started a Prefill vLLM instance.")
print(res_str)
response = await make_response(res_str)
return response
# prefill_addr, prefill_zmq_addr = random.choice(
# list(prefill_instances.items()))
prefill_addr, prefill_zmq_addr = list(prefill_instances.items())[0]
print(
"handle_request, prefill_addr: %s, zmq_addr: %s",
prefill_addr,
prefill_zmq_addr,
)
global decode_instances
global decode_cv
with decode_cv:
if len(decode_instances) > 1:
print(
"Found more than 1 Decode instances. Currently we only support 1P1D, so only"
f"the first Decode instance({list(decode_instances.keys())[0]}) will be used!"
)
if len(decode_instances) == 0:
res_str = (
"No Decode instances has been registered to proxy. Please confirm that you have successfully"
" and correctly started a Decode vLLM instance.")
print(res_str)
response = await make_response(res_str)
return response
# decode_addr, decode_zmq_addr = random.choice(
# list(decode_instances.items()))
decode_addr, decode_zmq_addr = list(decode_instances.items())[0]
print(
"handle_request, decode_addr: %s, zmq_addr: %s",
decode_addr,
decode_zmq_addr,
)
request_id = f"___prefill_addr_{prefill_addr}___decode_addr_{decode_addr}_{random_uuid()}"
# finish prefill
async for _ in forward_request(f"http://{prefill_addr}/v1/completions",
prefill_request, request_id):
continue
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions",
original_request_data,
request_id,
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
app.run(host="0.0.0.0", port=10001)
t.join()

View File

@@ -0,0 +1,37 @@
export HCCL_IF_IP=2.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0"
export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=0
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--host 0.0.0.0 \
--port 20002 \
--tensor-parallel-size 8 \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "AscendSimpleConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_parallel_size": 8,
"kv_port":"21001",
"kv_connector_extra_config":
{"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"],
"decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"],
"llmdatadist_comm_port": 26000,
"proxy_ip":"3.0.0.0",
"proxy_port":"30001",
"http_port": 10002}
}'

View File

@@ -0,0 +1,37 @@
export HCCL_IF_IP=1.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0"
export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=0
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--host 0.0.0.0 \
--port 10002 \
--tensor-parallel-size 8 \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "AscendSimpleConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_producer",
"kv_parallel_size": 8,
"kv_port":"11001",
"kv_connector_extra_config":
{"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"],
"decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"],
"llmdatadist_comm_port": 26000,
"proxy_ip":"3.0.0.0",
"proxy_port":"30001",
"http_port": 10002}
}'

30
examples/run_dp_server.sh Normal file
View File

@@ -0,0 +1,30 @@
export HCCL_IF_IP=2.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0"
export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=0
export ASCEND_RT_VISIBLE_DEVICES=0,1
export VLLM_DP_SIZE=2
export VLLM_DP_RANK=0
export VLLM_DP_MASTER_IP="2.0.0.0"
export VLLM_DP_MASTER_PORT=40001
export VLLM_DP_PROXY_IP="2.0.0.0"
export VLLM_DP_PROXY_PORT=30002
export VLLM_DP_MONITOR_PORT=30003
export VLLM_HTTP_PORT=20001
vllm serve /data/weights/Qwen2.5-0.5B-Instruct \
--host 0.0.0.0 \
--port 20001 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name Qwen \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \

View File

@@ -13,3 +13,7 @@ torch-npu==2.5.1
torch>=2.5.1
torchvision<0.21.0
wheel
# requirements for disaggregated prefill
msgpack
quart

View File

@@ -123,7 +123,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
"model_name": QUANT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -169,7 +169,7 @@ def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -230,7 +230,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# Main model
"model_name": FLOAT_MODEL,
"gpu_memory_utilization": 0.85
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -274,7 +274,7 @@ def test_mtp_e2e_greedy_correctness_torchair_graph(
# Main model
"model_name": QUANT_MODEL,
"gpu_memory_utilization": 0.85
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -322,7 +322,7 @@ def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -369,7 +369,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -420,7 +420,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])

View File

@@ -1,6 +1,27 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed.kv_transfer.kv_connector.factory import \
KVConnectorFactory
KVConnectorFactory.register_connector(
"AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
"LLMDataDistConnector")
KVConnectorFactory.register_connector(
"AscendSimpleConnector",
"vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector")

View File

@@ -0,0 +1,209 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import zlib
from typing import List, Optional
import llm_datadist # type: ignore
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
KVLookupBufferBase
from vllm.logger import init_logger
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
from vllm_ascend.distributed.kv_transfer.utils import TORCH_DTYPE_TO_NPU_DTYPE
logger = init_logger(__name__)
# Hash a string into a int32 value.
def int32_hash(data):
assert isinstance(data, str)
data = data.encode("utf-8")
return zlib.adler32(data)
class SimpleBuffer(KVLookupBufferBase):
def __init__(self, data_pipe: SimplePipe):
self.data_pipe = data_pipe
# Consumer buffer need these information to construct receiving buffer.
self.num_layers = None
self.num_heads = None
self.head_size = None
self.dtype = None
self.hidden_size = None
self.key_buffer = None
self.value_buffer = None
self.hidden_buffer = None
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
req_id: str,
) -> None:
"""
seq_len: num_tokens of current request.
input_tokens: [seq_len]
roi: [seq_len]
key: [num_layers, seq_len, num_kv_heads, head_size]
value: [num_layers, seq_len, num_kv_heads, head_size]
hidden: [seq_len, hidden_size]
"""
orig_k_shape = key.shape
num_layers = orig_k_shape[0]
# unsequeeze all tensors to make first dim to 1.
# This is because D node can only pull one batch data from P.
# So we make first dim to 1 here in order to pull full data.
key = key.view(num_layers, -1).unsqueeze(0)
value = value.view(num_layers, -1).unsqueeze(0)
hidden = hidden.unsqueeze(0)
hidden_dtype = key.dtype
# initialize LLMDatadist data structure
key_desc = llm_datadist.CacheDesc(
1,
key.shape,
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
seq_len_dim_index=1,
)
value_desc = llm_datadist.CacheDesc(
1,
value.shape,
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
seq_len_dim_index=1,
)
hidden_desc = llm_datadist.CacheDesc(
1,
hidden.shape,
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
seq_len_dim_index=-1,
)
req_id = int32_hash(req_id)
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 1)
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 2)
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 3)
# Currently we use hash value of request id as key, so no need to send input_tokens
self.key_buffer = self.data_pipe.send_tensor(key, key_desc,
key_cache_key)
self.value_buffer = self.data_pipe.send_tensor(value, value_desc,
value_cache_key)
self.hidden_buffer = self.data_pipe.send_tensor(
hidden, hidden_desc, hidden_cache_key)
def drop_select(
self,
input_tokens: torch.Tensor,
roi: Optional[torch.Tensor],
req_id: str,
) -> List[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
A list of tensors including:
key: [num_layers, num_tokens, num_heads, head_size]
value: [num_layers, num_tokens, num_heads, head_size]
hidden_or_intermediate_states: [num_tokens, hidden_size]
roi: None (Currently we don't supported roi)
"""
orig_req_id = req_id
req_id = int32_hash(req_id)
num_tokens = input_tokens.shape[0]
kv_shape = (
1,
self.num_layers,
num_tokens * self.num_heads * self.head_size,
)
hidden_shape = (1, num_tokens, self.hidden_size)
key_desc = llm_datadist.CacheDesc(
1,
kv_shape,
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
seq_len_dim_index=-1,
)
value_desc = llm_datadist.CacheDesc(
1,
kv_shape,
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
seq_len_dim_index=-1,
)
hidden_desc = llm_datadist.CacheDesc(
1,
hidden_shape,
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
seq_len_dim_index=-1,
)
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 1)
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 2)
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 3)
# Deallocate buffer allocated in last round.
if self.key_buffer:
try:
self.data_pipe.deallocate_buffer(self.key_buffer)
self.data_pipe.deallocate_buffer(self.value_buffer)
self.data_pipe.deallocate_buffer(self.hidden_buffer)
except Exception as e:
logger.warning(
f"Failed to free kv cache buffer, Error code: {str(e)}")
try:
self.key_buffer, key = self.data_pipe.recv_tensor(
key_desc, key_cache_key)
self.value_buffer, value = self.data_pipe.recv_tensor(
value_desc, value_cache_key)
self.hidden_buffer, hidden = self.data_pipe.recv_tensor(
hidden_desc, hidden_cache_key)
key = key.view(self.num_layers, num_tokens, self.num_heads,
self.head_size)
value = value.view(self.num_layers, num_tokens, self.num_heads,
self.head_size)
hidden = hidden.view(num_tokens, self.hidden_size)
except Exception as e:
logger.warning(
f"Faile to receive kv cache and hidden states of request: {orig_req_id} "
f"Error is {str(e)}")
return [None, None, None, None]
return [key, value, hidden, roi]
def close(self):
pass

View File

@@ -0,0 +1,376 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch_npu
import vllm.envs as vllm_envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class SimpleConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.config = config
self.model_config = config.model_config.hf_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.rank = rank
self.local_rank = local_rank
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not vllm_envs.VLLM_MLA_DISABLE
self.n_layer = self.config.model_config.get_num_layers(
self.config.parallel_config)
self.producer_data_pipe: Optional[SimplePipe]
self.consumer_data_pipe: Optional[SimplePipe]
self.producer_buffer: Optional[SimpleBuffer]
self.consumer_buffer: Optional[SimpleBuffer]
if self.config.kv_transfer_config.is_kv_producer:
self.producer_data_pipe = SimplePipe(
rank=rank,
local_rank=local_rank,
kv_transfer_config=config.kv_transfer_config,
hostname="",
port_offset=rank,
)
self.producer_buffer = SimpleBuffer(self.producer_data_pipe)
else:
self.consumer_data_pipe = SimplePipe(
rank=rank,
local_rank=local_rank,
kv_transfer_config=config.kv_transfer_config,
hostname="",
port_offset=rank,
)
self.consumer_buffer = SimpleBuffer(self.consumer_data_pipe)
def select(
self,
input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor],
req_id: str,
) -> List[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, (
"Please initialize the "
"consumer buffer before calling select.")
return self.consumer_buffer.drop_select(input_tokens, roi, req_id)
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
hidden: torch.Tensor,
req_id: str,
) -> None:
assert self.producer_buffer is not None, (
"Please initialize the "
"producer buffer before calling insert.")
self.producer_buffer.insert(input_tokens, roi, keys, values, hidden,
req_id)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
model_config = self.model_config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = (model_config.kv_lora_rank +
model_config.qk_rope_head_dim)
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = (model_config.qk_nope_head_dim +
model_config.qk_rope_head_dim)
else:
head_size = getattr(
model_config,
"head_dim",
int(hidden_size // num_attention_heads),
)
# Enumerate over all requests and insert them one by one.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent.")
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
# shape: [num_layers, num_tokens, num_heads, head_size]
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
# Currently we haven't considered situation of roi, pass None here.
self.insert(
current_tokens,
None,
keys,
values,
hidden_or_intermediate_states[start_pos:end_pos],
cur_req_id,
)
logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata", ]:
bypass_model_exec = True
model_config = self.model_config
# get model config
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
num_heads, head_dim = kv_caches[0].shape[-2:]
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
num_layers = end_layer - start_layer
if self.is_deepseek_mla and self.use_mla_opt:
head_size = (model_config.kv_lora_rank +
model_config.qk_rope_head_dim)
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = (model_config.qk_nope_head_dim +
model_config.qk_rope_head_dim)
else:
head_size = getattr(
model_config,
"head_dim",
int(hidden_size // num_attention_heads),
)
self.consumer_buffer.num_heads = num_heads # type: ignore
self.consumer_buffer.num_layers = num_layers # type: ignore
self.consumer_buffer.head_size = head_size # type: ignore
self.consumer_buffer.dtype = kv_caches[0].dtype # type: ignore
self.consumer_buffer.hidden_size = hidden_size # type: ignore
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
total_tokens = model_input.attn_metadata.num_prefill_tokens + model_input.attn_metadata.num_decode_tokens
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
# enumerate different requests
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to --max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
ret = self.select(
current_tokens,
torch.ones_like(current_tokens, dtype=bool),
cur_req_id,
)
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue
keys: torch.Tensor = ret[0]
values: torch.Tensor = ret[1]
hidden: torch.Tensor = ret[2]
num_computed_tokens = keys.shape[1]
num_computed_tokens_list.append(num_computed_tokens)
# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for i in range(
model_executable.model.start_layer,
model_executable.model.end_layer,
):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
key_cache = kv_cache
slots = slot_mapping[start_pos:end_pos]
sliced_key = keys[i - model_executable.model.start_layer]
torch_npu._npu_reshape_and_cache_siso(key=sliced_key,
key_cache=key_cache,
slot_indices=slots)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
sliced_key = keys[i - model_executable.model.start_layer]
sliced_value = values[i -
model_executable.model.start_layer]
torch_npu._npu_reshape_and_cache(
key=sliced_key,
value=sliced_value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slot_mapping[start_pos:end_pos],
)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
if get_dp_group().world_size > 1:
bypass_model_exec = True
hidden_or_intermediate_states = torch.empty(
[total_tokens, hidden_size],
dtype=kv_caches[0].dtype,
device=kv_caches[0].device)
logger.warning(
"[Detect there is more one DP rank in this decode node, in this scenario, no recompute is expected when kv cache dose not received.]"
)
else:
logger.warning(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.",
torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.",
torch.distributed.get_rank(),
)
# Can't directly concat here which might cause error when bs = 1.
# hidden_or_intermediate_states = torch.empty(total_num_tokens, hidden_size, dtype=kv_caches[0].dtype, device=kv_caches[0].device)
if len(hidden_or_intermediate_states_for_one_req) == 1:
hidden = hidden_or_intermediate_states_for_one_req[0]
tmp_indice = torch.tensor([0] * hidden.shape[0],
dtype=torch.int64).npu()
hidden_or_intermediate_states = torch.empty_like(hidden)
torch_npu.scatter_update_(
hidden_or_intermediate_states,
tmp_indice,
hidden,
axis=-1,
)
else:
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
self.producer_data_pipe.close() # type: ignore
self.consumer_data_pipe.close() # type: ignore
self.producer_buffer.close() # type: ignore
self.consumer_buffer.close() # type: ignore

View File

@@ -0,0 +1,209 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import time
from typing import Optional
import llm_datadist # type: ignore
import msgpack # type: ignore
import torch
import torch_npu
import torchair # type: ignore
import zmq # type: ignore
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils import get_ip
import vllm_ascend.envs as envs
from vllm_ascend.distributed.kv_transfer.utils import NPU_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
class SimplePipe(KVPipeBase):
def __init__(
self,
rank,
local_rank,
kv_transfer_config,
hostname: str = "",
port_offset: int = 0, # NPU offset in current P/D instance.
):
self.rank = rank
self.local_rank = local_rank
# Currently for 1P1D situation, we use cluster_id=0 for both Prefill and Decode
# Will change here in the future to support xPyD.
self.cluster_id = 0
self.config = kv_transfer_config
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
kv_role = kv_transfer_config.kv_role
if kv_role == "kv_producer":
self.role = llm_datadist.LLMRole.PROMPT
elif kv_role == "kv_consumer":
self.role = llm_datadist.LLMRole.DECODER
else:
raise NotImplementedError(
"kv_role should be inside [kv_producer, kv_consumer]")
prompt_device_ips = kv_connector_extra_config.get(
"prompt_device_ips", None)
decode_device_ips = kv_connector_extra_config.get(
"decode_device_ips", None)
if prompt_device_ips is None or decode_device_ips is None:
raise ValueError(
"Please specify prompt_device_ips and decode_device_ips"
"in kv_transfer_config.kv_connector_extra_config")
p_device_num = len(prompt_device_ips)
d_device_num = len(decode_device_ips)
# When number of devices in P and D is not equal,
# we assume that device in D can be mapped to any device in P.
self.p_device_rank = self.rank % p_device_num
self.d_device_rank = self.rank % d_device_num
self.prompt_ip_list = prompt_device_ips
self.decode_ip_list = decode_device_ips
self.llmdatadist_comm_port = kv_connector_extra_config.get(
"llmdatadist_comm_port", 26000)
# LLMDataDist initializing.
self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)
self._prepare_data_dist()
# Decoder needs to initialize and link cluster
if self.role == llm_datadist.LLMRole.DECODER:
self.cluster = self._make_cluster()
_, ret = self.data_dist.link_clusters([self.cluster], 20000)
logger.info(
f"rank {self.rank}, local_rank {self.local_rank} link, ret={ret}"
)
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self._register_thread = None
if port_offset == 0 and self.proxy_address != "":
# Initialize zmq socket and register to proxy.
# Note that only NPU 0 of each P/D instance register to proxy.
if not hostname:
hostname = get_ip() # Get ip of current host.
port = kv_transfer_config.kv_port + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._port = port
# Each card corresponds to a ZMQ address.
self.zmq_address = f"{self._hostname}:{self._port}"
self.context = zmq.Context() # type: ignore
self.router_socket = self.context.socket(
zmq.ROUTER) # type: ignore
self.router_socket.bind(f"tcp://{self.zmq_address}")
# The `http_port` must be consistent with the serving port of OpenAI.
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self._register_thread = threading.Thread(
target=self._register_to_proxy, daemon=True)
self._register_thread.start()
def _prepare_data_dist(self):
options = {
"llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
}
if self.role == llm_datadist.LLMRole.PROMPT:
options["ge.exec.deviceId"] = str(self.local_rank)
options["llm.listenIpInfo"] = (
f"{self.prompt_ip_list[self.p_device_rank]}:{self.llmdatadist_comm_port}"
)
else:
options["ge.exec.deviceId"] = str(self.local_rank)
print(f"prepare datadist, options: {options}")
self.data_dist.init(options)
self.kv_transfer = self.data_dist.kv_cache_manager
print(f"{self.rank} rank data dist is ready")
def _make_cluster(self):
cluster = llm_datadist.LLMClusterInfo()
cluster.remote_cluster_id = self.cluster_id
local_ip = self.decode_ip_list[self.d_device_rank]
remote_ip = self.prompt_ip_list[self.p_device_rank]
cluster.append_local_ip_info(local_ip, 0)
cluster.append_remote_ip_info(remote_ip, self.llmdatadist_comm_port)
return cluster
def _register_to_proxy(self):
sock = self.context.socket(zmq.DEALER) # type: ignore
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) # type: ignore
logger.debug("ping start, zmq_address:%s", self.zmq_address)
sock.connect(f"tcp://{self.proxy_address}")
data = {
"type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"zmq_address": self.zmq_address,
}
while True:
sock.send(msgpack.dumps(data))
time.sleep(3)
def send_tensor(
self,
tensor: Optional[torch.Tensor],
tensor_desc: llm_datadist.CacheDesc,
tensor_key: llm_datadist.CacheKey,
) -> llm_datadist.Cache:
buffer = self.kv_transfer.allocate_cache(tensor_desc, [tensor_key])
buffer_addr = buffer.per_device_tensor_addrs[0]
data_tensor = torchair.llm_datadist.create_npu_tensors(
tensor_desc.shape, tensor.dtype, buffer_addr)[0] # type: ignore
update_indices = torch.tensor(
[0] * tensor.shape[0], # type: ignore
dtype=torch.int64).npu()
torch_npu.scatter_update_(data_tensor, update_indices, tensor, axis=-1)
# Free cache_id of buffer, actual deallocate will happen after consumer performing pull_cache.
self.kv_transfer.deallocate_cache(buffer)
return buffer
def recv_tensor(
self,
tensor_desc: llm_datadist.CacheDesc,
tensor_key: llm_datadist.CacheKey,
) -> llm_datadist.Cache:
"""Note that this function only creates empty tensor on buffer addr and returns it."""
tmp_buffer = self.kv_transfer.allocate_cache(tensor_desc)
buffer_addr = tmp_buffer.per_device_tensor_addrs[0]
data_tensor = torchair.llm_datadist.create_npu_tensors(
tensor_desc.shape,
NPU_DTYPE_TO_TORCH_DTYPE[tensor_desc.data_type],
buffer_addr,
)[0]
self.kv_transfer.pull_cache(tensor_key, tmp_buffer, 0)
# tmp_buffer is allocated without key and will be deallocated here immediately.
# Free buffer here will cause accuracy problem.
# self.kv_transfer.deallocate_cache(tmp_buffer)
return tmp_buffer, data_tensor
def deallocate_buffer(self, buffer: llm_datadist.Cache):
self.kv_transfer.deallocate_cache(buffer)
def close(self):
self.data_dist.unlink_clusters([self.cluster], 5000)

View File

@@ -0,0 +1,40 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import llm_datadist # type: ignore
import torch
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
torch.float16: llm_datadist.DataType.DT_FLOAT16,
torch.bfloat16: llm_datadist.DataType.DT_BF16,
torch.float: llm_datadist.DataType.DT_FLOAT,
torch.float32: llm_datadist.DataType.DT_FLOAT,
torch.int8: llm_datadist.DataType.DT_INT8,
torch.int64: llm_datadist.DataType.DT_INT64,
torch.int32: llm_datadist.DataType.DT_INT32,
}
NPU_DTYPE_TO_TORCH_DTYPE = {
llm_datadist.DataType.DT_FLOAT16: torch.half,
llm_datadist.DataType.DT_FLOAT16: torch.float16,
llm_datadist.DataType.DT_BF16: torch.bfloat16,
llm_datadist.DataType.DT_FLOAT: torch.float,
llm_datadist.DataType.DT_FLOAT: torch.float32,
llm_datadist.DataType.DT_INT8: torch.int8,
llm_datadist.DataType.DT_INT64: torch.int64,
llm_datadist.DataType.DT_INT32: torch.int32,
}

View File

@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed import get_dp_group, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -1343,6 +1343,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
kv_caches=kv_caches
)
if get_dp_group().world_size > 1:
bypass_model_exec_tensor = torch.tensor(
1, dtype=torch.int32) if bypass_model_exec else torch.tensor(
0, dtype=torch.int32)
torch.distributed.all_reduce(bypass_model_exec_tensor,
op=torch.distributed.ReduceOp.MIN,
group=get_dp_group().cpu_group)
# If there is any group have not receive the necessary hidden states or kv_cache, we force all the dp group execute.
if bypass_model_exec_tensor.item() == 0:
bypass_model_exec = False
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
@@ -1399,10 +1410,21 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
torch.tensor(model_forward_time +
orig_model_forward_time))
return hidden_or_intermediate_states
# TODO: remove the synchronize here
torch.npu.synchronize()
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Sending KV cache in distributed KV cache transfer setting
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
if not self.is_driver_worker:
return []

View File

@@ -18,10 +18,13 @@
#
import gc
import os
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import msgpack # type: ignore
import torch
import torch.distributed
import zmq
from torch import nn
from vllm import envs
from vllm.config import VllmConfig, set_current_vllm_config
@@ -37,7 +40,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import GiB_bytes, bind_kv_cache
from vllm.utils import GiB_bytes, bind_kv_cache, get_ip
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
@@ -157,6 +160,33 @@ class NPUWorker(LocalOrDistributedWorkerBase):
else:
self.profiler = None
self.enable_dummy_run = False
if os.getenv("VLLM_DP_PROXY_IP", None):
logger.warning("enable dummy run for the DP")
self.enable_dummy_run = True
# dp_rank = os.environ["VLLM_DP_RANK"]
dp_master_ip = os.environ["VLLM_DP_PROXY_IP"]
dp_proxy_listener_port = os.environ["VLLM_DP_PROXY_PORT"]
dp_proxy_monitor_port = os.environ["VLLM_DP_MONITOR_PORT"]
dp_proxy_listener_addr = f"{dp_master_ip}:{dp_proxy_listener_port}"
self.dp_proxy_monitor_addr = f"{dp_master_ip}:{dp_proxy_monitor_port}"
http_ip = get_ip()
port = os.environ["VLLM_HTTP_PORT"]
self.http_addr = f"{http_ip}:{port}"
context = zmq.Context() # type: ignore
sock = context.socket(zmq.DEALER) # type: ignore
logger.debug("ping dp proxy start, DP_RANK:%s", 0)
# logger.debug("ping dp proxy start, DP_RANK:%s", dp_rank)
sock.connect(f"tcp://{dp_proxy_listener_addr}")
data = {"type": "DP", "http_address": self.http_addr}
for _ in range(10):
sock.send(msgpack.dumps(data))
self.notify_socket = context.socket(zmq.PUSH) # type: ignore
self.notify_socket.connect(f"tcp://{self.dp_proxy_monitor_addr}")
def sleep(self, level: int = 1) -> None:
NPUPlatform.set_device(self.device)
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
@@ -375,6 +405,11 @@ class NPUWorker(LocalOrDistributedWorkerBase):
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
if self.enable_dummy_run:
logger.debug(
f"send notify to the dp proxy: {self.dp_proxy_monitor_addr}")
data = {"info": "notify_step", "http_address": self.http_addr}
self.notify_socket.send(msgpack.dumps(data))
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None