[Disaggregated Prefill] P2P Disaggregated Prefill based on llm_datadist (#694)

### What this PR does / why we need it?
- This PR proposes a P2P version of Disaggregated Prefill based on
llm_datadist which manages data transfer.

- This solution reconstructs previous offline single-node Disaggregated
Prefill solution, and supports multi-node and online serveing now.

- Currently this solution supports 1P1D situation of Deepseek hybrid
parallelism (P: TP+EP, D: DP+EP). Note that xPyD situation is considered
in the solution design, and will be supported soon within v1 engine.

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
Co-authored-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
whx
2025-05-01 22:31:36 +08:00
committed by GitHub
parent 84e2ed898b
commit 8b194ad12e
18 changed files with 1769 additions and 32 deletions

View File

@@ -136,18 +136,9 @@ jobs:
id: filter_spec_decode id: filter_spec_decode
uses: dorny/paths-filter@v3 uses: dorny/paths-filter@v3
with: with:
# speculative decode seems will cause oom issue, disable it now on ci test
filters: | filters: |
speculative_tests_changed: speculative_tests_changed: 'false'
- "tests/singlecard/spec_decode/**"
- "tests/multicard/spec_decode_e2e/**"
- "vllm_ascend/worker/worker.py"
- "vllm_ascend/worker/model_runner.py"
- "vllm_ascend/worker/multi_step_runner.py"
- "vllm_ascend/worker/multi_step_worker.py"
- "vllm_ascend/worker/draft_model_runner.py"
- "vllm_ascend/patch/worker/patch_common/patch_metrics.py"
- "vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py"
- "vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py"
- name: Run vllm-project/vllm-ascend Speculative Decode test - name: Run vllm-project/vllm-ascend Speculative Decode test
if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true' || github.event_name == 'schedule' if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true' || github.event_name == 'schedule'

View File

@@ -2,12 +2,22 @@
This file demonstrates the example usage of disaggregated prefilling This file demonstrates the example usage of disaggregated prefilling
We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode), We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode),
and then transfer the KV cache between them. and then transfer the KV cache between them.
prompy_device_ips denotes device ip of NPU 0,1
decode_device_ips denotes device ip of NPU 2,3
The device ips of all NPUs in current server can be found through
examples/disaggregated_prefill/find_device_ips.py
""" """
import multiprocessing as mp import multiprocessing as mp
import os import os
import time import time
from multiprocessing import Event, Process from multiprocessing import Event, Process
kv_connector_extra_config = {
"prompt_device_ips": ["1.2.3.1", "1.2.3.2"],
"decode_device_ips": ["1.2.3.9", "1.2.3.10"],
"llmdatadist_comm_port": 26000,
}
def clean_up(): def clean_up():
import gc import gc
@@ -34,11 +44,10 @@ def run_prefill(prefill_done, process_close):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
ktc = KVTransferConfig.from_cli( ktc = KVTransferConfig.from_cli(
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}' '{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
) )
global kv_connector_extra_config
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB ktc.kv_connector_extra_config = kv_connector_extra_config
# memory. You may need to adjust the value to fit your GPU.
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc, kv_transfer_config=ktc,
max_model_len=2000, max_model_len=2000,
@@ -69,15 +78,16 @@ def run_decode(prefill_done):
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
prompts = [ prompts = [
"Hello, how are you today?", "Hi, what is your name?", "Hello, how are you today?",
"Tell me a very long story.", "what is your favourite book?" "Hi, what is your name?",
] ]
sampling_params = SamplingParams(temperature=0, top_p=0.95) sampling_params = SamplingParams(temperature=0, top_p=0.95)
ktc = KVTransferConfig.from_cli( ktc = KVTransferConfig.from_cli(
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}' '{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
) )
global kv_connector_extra_config
ktc.kv_connector_extra_config = kv_connector_extra_config
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc, kv_transfer_config=ktc,
max_model_len=2000, max_model_len=2000,

View File

@@ -0,0 +1,463 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import copy
import logging
import os
import threading
import time
import uuid
import aiohttp
import msgpack # type: ignore
import zmq
from quart import Quart, make_response, request
DP_PROXY_HTTP_PORT = 10004
DP_PROXY_ZMQ_REG_PORT = 30006
DP_PROXY_ZMQ_NOTIFY_PORT = 30005
PD_PROXY_ADDRESS = "127.0.0.1:30002"
MY_HTTP_ADDRESS = f"127.0.0.1:{DP_PROXY_HTTP_PORT}"
MY_ZMQ_ADDRESS_PLACEHOLDER = f"127.0.0.1:{DP_PROXY_ZMQ_REG_PORT}"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
TIME_INTERVAL_FOR_IDLE_RUN = 5e-4
DP_SIZE = 2
dp_instances: dict[str, bool] = {}
dp_cv = threading.Condition()
round_robin_index = 0
_idle_send_loop = None
def make_idle_request():
# Same as before
data = {
"prompt": "hi",
"max_tokens": 1,
"temperature": 0,
}
return data
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def send_idle_token_to_client(schedule_dict):
for key, value in schedule_dict.items():
if value:
continue
request_received_id = random_uuid()
idle_request_data = make_idle_request()
forward_request_id = f"dp_idle_{key}_{request_received_id}"
target_url = f'http://{key}/v1/completions'
logger.debug(
f"DP Decode Proxy: Sending idle token to D node {key} at {target_url}"
)
generator = forward_request_internal(target_url, idle_request_data,
forward_request_id)
try:
async for response in generator:
logger.debug(
f"DP Decode Proxy: Idle Request {request_received_id}: response from {key}, got response: {response}"
)
except Exception as e:
logger.warning(
f"DP Decode Proxy: Error sending idle token to {key}: {e}")
def metadata_collect_trigger(poller, router_socket):
global dp_instances
global dp_cv
global _idle_send_loop
with dp_cv:
dp_cv.wait()
while True:
try:
schedule_dict = copy.deepcopy(dp_instances)
for key in schedule_dict.keys():
schedule_dict[key] = False
first_start = False
start_time = None
while not all(schedule_dict.values()):
if start_time is not None:
time_interval = time.time() - start_time
logger.debug("check time interval: ", time_interval)
if time_interval > TIME_INTERVAL_FOR_IDLE_RUN:
logger.debug(
"exceeds max time interval send idle token to client"
)
# Send idle token to client in case of single dp rank run solo and block on the CCL part
asyncio.run_coroutine_threadsafe(
send_idle_token_to_client(schedule_dict),
_idle_send_loop) # type: ignore
# Note: Reset start time prevent consistently send idle token to client
# We only reset start time here, for some of the client may loss the idle token send from this proxy
# and we only exit this while loop when we make sure all the client are exactly start inference in this
# step
start_time = time.time()
socks = dict(poller.poll(timeout=500)) # timeout in 500ms
if socks:
logger.debug("receive socks from moniter threads: ", socks)
if router_socket in socks:
messages = router_socket.recv_multipart()
try:
# {"info": "notify_step", "http_address": ""}
for message in messages:
data = msgpack.loads(message)
http_addr = None
logger.debug(f"receive message {data}")
if data.get("info") == "notify_step":
http_addr = data.get("http_address")
if http_addr in schedule_dict.keys():
schedule_dict[http_addr] = True
logger.debug("set first time")
if not first_start:
logger.debug("record start time")
first_start = True
start_time = time.time()
else:
logger.warning("Unrecognize http address")
else:
logger.warning(
"Got unrecognize info type! We only accept notify step info yet"
)
except (msgpack.UnpackException, TypeError, KeyError) as e:
logger.error(
f"Error processing message from {http_addr}: {e}. Message: {data}"
)
except zmq.ZMQError as e: # type: ignore
logger.error(f"ZMQ Error in monitor thread: {e}")
if e.errno == zmq.ETERM: # type: ignore
logger.error(
"Monitor thread terminating due to context termination.")
break
time.sleep(1)
except Exception as e:
logger.error(f"Unexpected error in monitor thread: {e}")
import traceback
traceback.print_exc()
time.sleep(1)
def _listen_for_d_register(poller, router_socket):
global dp_instances
global dp_cv
global DP_SIZE
logger.info(
f"DP Decode Proxy: D Node ZMQ Listener started on ROUTER port {DP_PROXY_ZMQ_REG_PORT}"
)
while True:
try:
socks = dict(poller.poll(timeout=1000))
if router_socket in socks:
remote_id, message = router_socket.recv_multipart()
try:
data = msgpack.loads(message)
if data.get("type") == "DP":
http_addr = data.get("http_address")
zmq_addr = data.get("zmq_address")
if http_addr:
with dp_cv:
if http_addr not in dp_instances:
logger.info(
f"DP Decode Proxy: Registering D Node instance: http={http_addr}, zmq={zmq_addr}"
)
dp_instances[http_addr] = True
if len(dp_instances) >= DP_SIZE:
logger.info(
f"DP Decode Proxy: Reached expected D Node count ({DP_SIZE}). Notifying metadata collector."
)
dp_cv.notify_all()
else:
pass
else:
logger.warning(
f"DP Decode Proxy: Received D Node registration from {remote_id.decode()} without http_address. Data: {data}"
)
else:
logger.warning(
f"DP Decode Proxy: Received message with unexpected type from {remote_id.decode()}. Type: {data.get('type')}, Data: {data}"
)
except (msgpack.UnpackException, TypeError, KeyError) as e:
logger.error(
f"DP Decode Proxy: Error processing D Node registration from {remote_id.decode()}: {e}. Message: {message}"
)
except Exception as e:
logger.error(
f"DP Decode Proxy: Unexpected error processing D Node registration from {remote_id.decode()}: {e}"
)
except zmq.ZMQError as e: # type: ignore
logger.error(
f"DP Decode Proxy: ZMQ Error in D Node listener thread: {e}")
if e.errno == zmq.ETERM: # type: ignore
logger.info(
"DP Decode Proxy: D Node Listener thread terminating.")
break
time.sleep(1)
except Exception as e:
logger.error(
f"DP Decode Proxy: Unexpected error in D Node listener thread: {e}"
)
import traceback
traceback.print_exc()
time.sleep(1)
def _register_to_pd_proxy(pd_proxy_zmq_addr, my_http_addr, my_zmq_addr):
context = None
sock = None
while True:
try:
if context is None:
context = zmq.Context() # type: ignore
if sock is None:
sock = context.socket(zmq.DEALER) # type: ignore
identity = f"dp_proxy_{my_http_addr}".encode('utf-8')
sock.setsockopt(zmq.IDENTITY, identity) # type: ignore
sock.setsockopt(zmq.LINGER, 0) # type: ignore
logger.info(
f"DP Decode Proxy: Attempting to connect to PD Proxy at {pd_proxy_zmq_addr}..."
)
sock.connect(f"tcp://{pd_proxy_zmq_addr}")
logger.info(
f"DP Decode Proxy: Connected to PD Proxy at {pd_proxy_zmq_addr}."
)
data = {
"type": "D",
"http_address": my_http_addr,
"zmq_address": my_zmq_addr
}
logger.debug(
f"DP Decode Proxy: Sending registration/heartbeat to PD Proxy: {data}"
)
sock.send(msgpack.dumps(data))
time.sleep(5)
except zmq.ZMQError as e: # type: ignore
logger.error(
f"DP Decode Proxy: ZMQ Error connecting/sending to PD Proxy ({pd_proxy_zmq_addr}): {e}"
)
if sock:
sock.close()
sock = None
time.sleep(10)
except Exception as e:
logger.error(
f"DP Decode Proxy: Unexpected error in PD Proxy registration thread: {e}"
)
import traceback
traceback.print_exc()
if sock:
sock.close()
sock = None
time.sleep(10)
finally:
pass
def start_zmq_thread(hostname, port, socket_type, target_func, thread_name):
"""Generic ZMQ thread starter for ROUTER or PULL."""
if not hostname:
hostname = "0.0.0.0"
context = zmq.Context.instance() # type: ignore
socket = context.socket(socket_type)
socket.setsockopt(zmq.LINGER, 0) # type: ignore
try:
socket.bind(f"tcp://{hostname}:{port}")
except zmq.ZMQError as e: # type: ignore
logger.error(
f"DP Decode Proxy: Error binding ZMQ {socket_type} socket to tcp://{hostname}:{port}: {e}"
)
socket.close()
raise
poller = zmq.Poller() # type: ignore
poller.register(socket, zmq.POLLIN) # type: ignore
thread = threading.Thread(target=target_func,
args=(poller, socket),
daemon=True,
name=thread_name)
thread.start()
return thread, socket
def start_thread_with_event_loop():
global _idle_send_loop
asyncio.set_event_loop(_idle_send_loop)
_idle_send_loop.run_forever() # type: ignore
async def forward_request_internal(url, data, request_id):
try:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization":
f"Bearer {os.environ.get('OPENAI_API_KEY', '')}",
"X-Request-Id": request_id,
"Content-Type": "application/json"
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(
1024):
yield chunk_bytes
else:
error_content = await response.read()
logger.warning(
f"DP Decode Proxy: Error from D node {url} (status {response.status}): {error_content.decode(errors='ignore')}"
)
yield error_content
except aiohttp.ClientError as e:
logger.warning(
f"DP Decode Proxy: Error forwarding request {request_id} to D node {url}: {e}"
)
error_msg = f"Failed to connect or communicate with D node at {url}: {e}".encode(
'utf-8')
yield error_msg
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
@app.route('/v1/completions', methods=['POST'])
async def handle_request():
global dp_instances
global dp_cv
global round_robin_index
request_received_id = request.headers.get("X-Request-Id")
if not request_received_id:
fallback_id = f"dp_fallback_{random_uuid()}"
logger.warning(
f"DP Decode Proxy: Received request without X-Request-Id header. Using fallback ID: {fallback_id}"
)
request_received_id = fallback_id
else:
logger.info(
f"DP Decode Proxy: Received request from PD Proxy, using propagated ID: {request_received_id}"
)
try:
original_request_data = await request.get_json()
if not original_request_data:
return await make_response("Request body must be valid JSON.", 400)
target_addr = None
with dp_cv:
if not dp_instances:
logger.warning(
f"DP Decode Proxy: Request {request_received_id}: No D Node instances available/registered."
)
return await make_response("No Decode instances available.",
503)
dp_addresses = list(dp_instances.keys())
if not dp_addresses:
logger.error(
f"DP Decode Proxy: Request {request_received_id}: Internal error - dp_instances populated but list is empty."
)
return await make_response("Internal Server Error", 500)
current_selection_index = round_robin_index % len(dp_addresses)
target_addr = dp_addresses[current_selection_index]
round_robin_index += 1
logger.info(
f"DP Decode Proxy: Request {request_received_id}: Routing Decode to D Node {target_addr} (Index {current_selection_index})"
)
target_url = f'http://{target_addr}/v1/completions'
generator = forward_request_internal(target_url, original_request_data,
request_received_id)
response = await make_response(generator)
response.timeout = None
if original_request_data.get("stream", False):
response.headers['Content-Type'] = 'text/event-stream'
response.headers['Cache-Control'] = 'no-cache'
else:
response.headers['Content-Type'] = 'application/json'
logger.debug(
f"DP Decode Proxy: Request {request_received_id}: Streaming response from D node {target_addr}"
)
return response
except Exception as e:
logger.error(
f"DP Decode Proxy: Error handling request {request_received_id}: {e}"
)
return await make_response("Internal Server Error", 500)
if __name__ == '__main__':
d_listener_thread, d_reg_socket = start_zmq_thread(
"0.0.0.0",
DP_PROXY_ZMQ_REG_PORT,
zmq.ROUTER, # type: ignore
_listen_for_d_register, # type: ignore
"DP_DNodeListenerThread")
metadata_thread, notify_socket = start_zmq_thread(
"0.0.0.0",
DP_PROXY_ZMQ_NOTIFY_PORT,
zmq.PULL, # type: ignore
metadata_collect_trigger,
"DP_MetadataMonitorThread")
_idle_send_loop = asyncio.new_event_loop()
idle_loop_thread = threading.Thread(target=start_thread_with_event_loop,
daemon=True,
name="DP_IdleSendLoopThread")
idle_loop_thread.start()
pd_register_thread = threading.Thread(target=_register_to_pd_proxy,
args=(PD_PROXY_ADDRESS,
MY_HTTP_ADDRESS,
MY_ZMQ_ADDRESS_PLACEHOLDER),
daemon=True,
name="DP_PDRegisterThread")
pd_register_thread.start()
logger.info(
f"DP Decode Proxy: Starting Quart web server on http://0.0.0.0:{DP_PROXY_HTTP_PORT}"
)
zmq_context = zmq.Context.instance() # type: ignore
try:
app.run(host='0.0.0.0', port=DP_PROXY_HTTP_PORT)
except KeyboardInterrupt:
logger.info("DP Decode Proxy: KeyboardInterrupt received, stopping...")
except Exception as e:
logger.error(f"DP Decode Proxy: Failed to run Quart server: {e}")
finally:
logger.info("DP Decode Proxy: Shutting down...")
if _idle_send_loop and _idle_send_loop.is_running():
logger.info("DP Decode Proxy: Stopping idle send loop...")
_idle_send_loop.call_soon_threadsafe(_idle_send_loop.stop)
if d_reg_socket:
d_reg_socket.close()
if notify_socket:
notify_socket.close()
if zmq_context:
zmq_context.term()
logger.info("DP Decode Proxy: Shutdown complete.")

View File

@@ -0,0 +1,67 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
#
"""
This file provides a function to obtain ips of all NPU Devices in current machine.
"""
import os
import re
import subprocess
import vllm_ascend.envs as envs
# Get all device ips using hccn_tool
HCCN_TOOL_PATH = envs.HCCN_PATH
def get_device_ips(world_size: int):
npu_info = subprocess.run(
["npu-smi", "info", "-m"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
)
if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")
npu_start_idx = int(
re.match(r".*\n\t([0-9]+).*",
npu_info.stdout).group(1)) # type: ignore
device_ip_list = []
for ip_offset in range(world_size):
cmd = [
HCCN_TOOL_PATH,
"-i",
f"{npu_start_idx + ip_offset}",
"-ip",
"-g",
]
device_ip_info = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
)
device_ip = re.match(r"ipaddr:(.*)\n",
device_ip_info.stdout).group(1) # type: ignore
device_ip_list.append(device_ip)
return device_ip_list
# Pass number of NPUs into this function.
print(get_device_ips(8))

View File

@@ -0,0 +1,186 @@
import os
import socket
import threading
import uuid
import aiohttp
import msgpack # type: ignore
import zmq
from quart import Quart, make_response, request
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[
data["http_address"]] = data["zmq_address"]
print(
"Get a prefill register with http_addr %s and zmq_addr %s",
data["http_address"],
data["zmq_address"],
)
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[
data["http_address"]] = data["zmq_address"]
print(
"Get a decode register with http_addr %s and zmq_addr %s",
data["http_address"],
data["zmq_address"],
)
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context() # type: ignore
router_socket = context.socket(zmq.ROUTER) # type: ignore
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller() # type: ignore
poller.register(router_socket, zmq.POLLIN) # type: ignore
_listener_thread = threading.Thread(target=_listen_for_register,
args=[poller, router_socket],
daemon=True)
_listener_thread.start()
return _listener_thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global prefill_instances
global prefill_cv
with prefill_cv:
if len(prefill_instances) > 1:
print(
"Found more than 1 Prefill instances. Currently we only support 1P1D, so only"
f"the first Prefill instance({list(prefill_instances.keys())[0]}) will be used!"
)
if len(prefill_instances) == 0:
res_str = (
"No Prefill instances has been registered to proxy. Please confirm that you have successfully"
" and correctly started a Prefill vLLM instance.")
print(res_str)
response = await make_response(res_str)
return response
# prefill_addr, prefill_zmq_addr = random.choice(
# list(prefill_instances.items()))
prefill_addr, prefill_zmq_addr = list(prefill_instances.items())[0]
print(
"handle_request, prefill_addr: %s, zmq_addr: %s",
prefill_addr,
prefill_zmq_addr,
)
global decode_instances
global decode_cv
with decode_cv:
if len(decode_instances) > 1:
print(
"Found more than 1 Decode instances. Currently we only support 1P1D, so only"
f"the first Decode instance({list(decode_instances.keys())[0]}) will be used!"
)
if len(decode_instances) == 0:
res_str = (
"No Decode instances has been registered to proxy. Please confirm that you have successfully"
" and correctly started a Decode vLLM instance.")
print(res_str)
response = await make_response(res_str)
return response
# decode_addr, decode_zmq_addr = random.choice(
# list(decode_instances.items()))
decode_addr, decode_zmq_addr = list(decode_instances.items())[0]
print(
"handle_request, decode_addr: %s, zmq_addr: %s",
decode_addr,
decode_zmq_addr,
)
request_id = f"___prefill_addr_{prefill_addr}___decode_addr_{decode_addr}_{random_uuid()}"
# finish prefill
async for _ in forward_request(f"http://{prefill_addr}/v1/completions",
prefill_request, request_id):
continue
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions",
original_request_data,
request_id,
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
app.run(host="0.0.0.0", port=10001)
t.join()

View File

@@ -0,0 +1,37 @@
export HCCL_IF_IP=2.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0"
export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=0
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--host 0.0.0.0 \
--port 20002 \
--tensor-parallel-size 8 \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "AscendSimpleConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_parallel_size": 8,
"kv_port":"21001",
"kv_connector_extra_config":
{"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"],
"decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"],
"llmdatadist_comm_port": 26000,
"proxy_ip":"3.0.0.0",
"proxy_port":"30001",
"http_port": 10002}
}'

View File

@@ -0,0 +1,37 @@
export HCCL_IF_IP=1.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0"
export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=0
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--host 0.0.0.0 \
--port 10002 \
--tensor-parallel-size 8 \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "AscendSimpleConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_producer",
"kv_parallel_size": 8,
"kv_port":"11001",
"kv_connector_extra_config":
{"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"],
"decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"],
"llmdatadist_comm_port": 26000,
"proxy_ip":"3.0.0.0",
"proxy_port":"30001",
"http_port": 10002}
}'

30
examples/run_dp_server.sh Normal file
View File

@@ -0,0 +1,30 @@
export HCCL_IF_IP=2.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0"
export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=0
export ASCEND_RT_VISIBLE_DEVICES=0,1
export VLLM_DP_SIZE=2
export VLLM_DP_RANK=0
export VLLM_DP_MASTER_IP="2.0.0.0"
export VLLM_DP_MASTER_PORT=40001
export VLLM_DP_PROXY_IP="2.0.0.0"
export VLLM_DP_PROXY_PORT=30002
export VLLM_DP_MONITOR_PORT=30003
export VLLM_HTTP_PORT=20001
vllm serve /data/weights/Qwen2.5-0.5B-Instruct \
--host 0.0.0.0 \
--port 20001 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name Qwen \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \

View File

@@ -13,3 +13,7 @@ torch-npu==2.5.1
torch>=2.5.1 torch>=2.5.1
torchvision<0.21.0 torchvision<0.21.0
wheel wheel
# requirements for disaggregated prefill
msgpack
quart

View File

@@ -123,7 +123,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
"model_name": QUANT_MODEL, "model_name": QUANT_MODEL,
# GPU memory utilization # GPU memory utilization
"gpu_memory_utilization": 0.85 "gpu_memory_utilization": 0.8
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -169,7 +169,7 @@ def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
"model_name": FLOAT_MODEL, "model_name": FLOAT_MODEL,
# GPU memory utilization # GPU memory utilization
"gpu_memory_utilization": 0.85 "gpu_memory_utilization": 0.8
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -230,7 +230,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": FLOAT_MODEL, "model_name": FLOAT_MODEL,
"gpu_memory_utilization": 0.85 "gpu_memory_utilization": 0.8
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -274,7 +274,7 @@ def test_mtp_e2e_greedy_correctness_torchair_graph(
# Main model # Main model
"model_name": QUANT_MODEL, "model_name": QUANT_MODEL,
"gpu_memory_utilization": 0.85 "gpu_memory_utilization": 0.8
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -322,7 +322,7 @@ def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
"model_name": FLOAT_MODEL, "model_name": FLOAT_MODEL,
# GPU memory utilization # GPU memory utilization
"gpu_memory_utilization": 0.9 "gpu_memory_utilization": 0.8
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -369,7 +369,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"model_name": FLOAT_MODEL, "model_name": FLOAT_MODEL,
# GPU memory utilization # GPU memory utilization
"gpu_memory_utilization": 0.9 "gpu_memory_utilization": 0.8
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -420,7 +420,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
"model_name": FLOAT_MODEL, "model_name": FLOAT_MODEL,
# GPU memory utilization # GPU memory utilization
"gpu_memory_utilization": 0.9 "gpu_memory_utilization": 0.8
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])

View File

@@ -1,6 +1,27 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed.kv_transfer.kv_connector.factory import \ from vllm.distributed.kv_transfer.kv_connector.factory import \
KVConnectorFactory KVConnectorFactory
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector", "AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
"LLMDataDistConnector") "LLMDataDistConnector")
KVConnectorFactory.register_connector(
"AscendSimpleConnector",
"vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector")

View File

@@ -0,0 +1,209 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import zlib
from typing import List, Optional
import llm_datadist # type: ignore
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
KVLookupBufferBase
from vllm.logger import init_logger
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
from vllm_ascend.distributed.kv_transfer.utils import TORCH_DTYPE_TO_NPU_DTYPE
logger = init_logger(__name__)
# Hash a string into a int32 value.
def int32_hash(data):
assert isinstance(data, str)
data = data.encode("utf-8")
return zlib.adler32(data)
class SimpleBuffer(KVLookupBufferBase):
def __init__(self, data_pipe: SimplePipe):
self.data_pipe = data_pipe
# Consumer buffer need these information to construct receiving buffer.
self.num_layers = None
self.num_heads = None
self.head_size = None
self.dtype = None
self.hidden_size = None
self.key_buffer = None
self.value_buffer = None
self.hidden_buffer = None
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
req_id: str,
) -> None:
"""
seq_len: num_tokens of current request.
input_tokens: [seq_len]
roi: [seq_len]
key: [num_layers, seq_len, num_kv_heads, head_size]
value: [num_layers, seq_len, num_kv_heads, head_size]
hidden: [seq_len, hidden_size]
"""
orig_k_shape = key.shape
num_layers = orig_k_shape[0]
# unsequeeze all tensors to make first dim to 1.
# This is because D node can only pull one batch data from P.
# So we make first dim to 1 here in order to pull full data.
key = key.view(num_layers, -1).unsqueeze(0)
value = value.view(num_layers, -1).unsqueeze(0)
hidden = hidden.unsqueeze(0)
hidden_dtype = key.dtype
# initialize LLMDatadist data structure
key_desc = llm_datadist.CacheDesc(
1,
key.shape,
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
seq_len_dim_index=1,
)
value_desc = llm_datadist.CacheDesc(
1,
value.shape,
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
seq_len_dim_index=1,
)
hidden_desc = llm_datadist.CacheDesc(
1,
hidden.shape,
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
seq_len_dim_index=-1,
)
req_id = int32_hash(req_id)
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 1)
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 2)
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 3)
# Currently we use hash value of request id as key, so no need to send input_tokens
self.key_buffer = self.data_pipe.send_tensor(key, key_desc,
key_cache_key)
self.value_buffer = self.data_pipe.send_tensor(value, value_desc,
value_cache_key)
self.hidden_buffer = self.data_pipe.send_tensor(
hidden, hidden_desc, hidden_cache_key)
def drop_select(
self,
input_tokens: torch.Tensor,
roi: Optional[torch.Tensor],
req_id: str,
) -> List[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
A list of tensors including:
key: [num_layers, num_tokens, num_heads, head_size]
value: [num_layers, num_tokens, num_heads, head_size]
hidden_or_intermediate_states: [num_tokens, hidden_size]
roi: None (Currently we don't supported roi)
"""
orig_req_id = req_id
req_id = int32_hash(req_id)
num_tokens = input_tokens.shape[0]
kv_shape = (
1,
self.num_layers,
num_tokens * self.num_heads * self.head_size,
)
hidden_shape = (1, num_tokens, self.hidden_size)
key_desc = llm_datadist.CacheDesc(
1,
kv_shape,
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
seq_len_dim_index=-1,
)
value_desc = llm_datadist.CacheDesc(
1,
kv_shape,
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
seq_len_dim_index=-1,
)
hidden_desc = llm_datadist.CacheDesc(
1,
hidden_shape,
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
seq_len_dim_index=-1,
)
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 1)
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 2)
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
req_id, 3)
# Deallocate buffer allocated in last round.
if self.key_buffer:
try:
self.data_pipe.deallocate_buffer(self.key_buffer)
self.data_pipe.deallocate_buffer(self.value_buffer)
self.data_pipe.deallocate_buffer(self.hidden_buffer)
except Exception as e:
logger.warning(
f"Failed to free kv cache buffer, Error code: {str(e)}")
try:
self.key_buffer, key = self.data_pipe.recv_tensor(
key_desc, key_cache_key)
self.value_buffer, value = self.data_pipe.recv_tensor(
value_desc, value_cache_key)
self.hidden_buffer, hidden = self.data_pipe.recv_tensor(
hidden_desc, hidden_cache_key)
key = key.view(self.num_layers, num_tokens, self.num_heads,
self.head_size)
value = value.view(self.num_layers, num_tokens, self.num_heads,
self.head_size)
hidden = hidden.view(num_tokens, self.hidden_size)
except Exception as e:
logger.warning(
f"Faile to receive kv cache and hidden states of request: {orig_req_id} "
f"Error is {str(e)}")
return [None, None, None, None]
return [key, value, hidden, roi]
def close(self):
pass

View File

@@ -0,0 +1,376 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch_npu
import vllm.envs as vllm_envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class SimpleConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.config = config
self.model_config = config.model_config.hf_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.rank = rank
self.local_rank = local_rank
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not vllm_envs.VLLM_MLA_DISABLE
self.n_layer = self.config.model_config.get_num_layers(
self.config.parallel_config)
self.producer_data_pipe: Optional[SimplePipe]
self.consumer_data_pipe: Optional[SimplePipe]
self.producer_buffer: Optional[SimpleBuffer]
self.consumer_buffer: Optional[SimpleBuffer]
if self.config.kv_transfer_config.is_kv_producer:
self.producer_data_pipe = SimplePipe(
rank=rank,
local_rank=local_rank,
kv_transfer_config=config.kv_transfer_config,
hostname="",
port_offset=rank,
)
self.producer_buffer = SimpleBuffer(self.producer_data_pipe)
else:
self.consumer_data_pipe = SimplePipe(
rank=rank,
local_rank=local_rank,
kv_transfer_config=config.kv_transfer_config,
hostname="",
port_offset=rank,
)
self.consumer_buffer = SimpleBuffer(self.consumer_data_pipe)
def select(
self,
input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor],
req_id: str,
) -> List[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, (
"Please initialize the "
"consumer buffer before calling select.")
return self.consumer_buffer.drop_select(input_tokens, roi, req_id)
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
hidden: torch.Tensor,
req_id: str,
) -> None:
assert self.producer_buffer is not None, (
"Please initialize the "
"producer buffer before calling insert.")
self.producer_buffer.insert(input_tokens, roi, keys, values, hidden,
req_id)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
model_config = self.model_config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = (model_config.kv_lora_rank +
model_config.qk_rope_head_dim)
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = (model_config.qk_nope_head_dim +
model_config.qk_rope_head_dim)
else:
head_size = getattr(
model_config,
"head_dim",
int(hidden_size // num_attention_heads),
)
# Enumerate over all requests and insert them one by one.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent.")
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
# shape: [num_layers, num_tokens, num_heads, head_size]
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
# Currently we haven't considered situation of roi, pass None here.
self.insert(
current_tokens,
None,
keys,
values,
hidden_or_intermediate_states[start_pos:end_pos],
cur_req_id,
)
logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata", ]:
bypass_model_exec = True
model_config = self.model_config
# get model config
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
num_heads, head_dim = kv_caches[0].shape[-2:]
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
num_layers = end_layer - start_layer
if self.is_deepseek_mla and self.use_mla_opt:
head_size = (model_config.kv_lora_rank +
model_config.qk_rope_head_dim)
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = (model_config.qk_nope_head_dim +
model_config.qk_rope_head_dim)
else:
head_size = getattr(
model_config,
"head_dim",
int(hidden_size // num_attention_heads),
)
self.consumer_buffer.num_heads = num_heads # type: ignore
self.consumer_buffer.num_layers = num_layers # type: ignore
self.consumer_buffer.head_size = head_size # type: ignore
self.consumer_buffer.dtype = kv_caches[0].dtype # type: ignore
self.consumer_buffer.hidden_size = hidden_size # type: ignore
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
total_tokens = model_input.attn_metadata.num_prefill_tokens + model_input.attn_metadata.num_decode_tokens
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
# enumerate different requests
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to --max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
ret = self.select(
current_tokens,
torch.ones_like(current_tokens, dtype=bool),
cur_req_id,
)
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue
keys: torch.Tensor = ret[0]
values: torch.Tensor = ret[1]
hidden: torch.Tensor = ret[2]
num_computed_tokens = keys.shape[1]
num_computed_tokens_list.append(num_computed_tokens)
# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for i in range(
model_executable.model.start_layer,
model_executable.model.end_layer,
):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
key_cache = kv_cache
slots = slot_mapping[start_pos:end_pos]
sliced_key = keys[i - model_executable.model.start_layer]
torch_npu._npu_reshape_and_cache_siso(key=sliced_key,
key_cache=key_cache,
slot_indices=slots)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
sliced_key = keys[i - model_executable.model.start_layer]
sliced_value = values[i -
model_executable.model.start_layer]
torch_npu._npu_reshape_and_cache(
key=sliced_key,
value=sliced_value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slot_mapping[start_pos:end_pos],
)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
if get_dp_group().world_size > 1:
bypass_model_exec = True
hidden_or_intermediate_states = torch.empty(
[total_tokens, hidden_size],
dtype=kv_caches[0].dtype,
device=kv_caches[0].device)
logger.warning(
"[Detect there is more one DP rank in this decode node, in this scenario, no recompute is expected when kv cache dose not received.]"
)
else:
logger.warning(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.",
torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.",
torch.distributed.get_rank(),
)
# Can't directly concat here which might cause error when bs = 1.
# hidden_or_intermediate_states = torch.empty(total_num_tokens, hidden_size, dtype=kv_caches[0].dtype, device=kv_caches[0].device)
if len(hidden_or_intermediate_states_for_one_req) == 1:
hidden = hidden_or_intermediate_states_for_one_req[0]
tmp_indice = torch.tensor([0] * hidden.shape[0],
dtype=torch.int64).npu()
hidden_or_intermediate_states = torch.empty_like(hidden)
torch_npu.scatter_update_(
hidden_or_intermediate_states,
tmp_indice,
hidden,
axis=-1,
)
else:
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
self.producer_data_pipe.close() # type: ignore
self.consumer_data_pipe.close() # type: ignore
self.producer_buffer.close() # type: ignore
self.consumer_buffer.close() # type: ignore

View File

@@ -0,0 +1,209 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import time
from typing import Optional
import llm_datadist # type: ignore
import msgpack # type: ignore
import torch
import torch_npu
import torchair # type: ignore
import zmq # type: ignore
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils import get_ip
import vllm_ascend.envs as envs
from vllm_ascend.distributed.kv_transfer.utils import NPU_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
class SimplePipe(KVPipeBase):
def __init__(
self,
rank,
local_rank,
kv_transfer_config,
hostname: str = "",
port_offset: int = 0, # NPU offset in current P/D instance.
):
self.rank = rank
self.local_rank = local_rank
# Currently for 1P1D situation, we use cluster_id=0 for both Prefill and Decode
# Will change here in the future to support xPyD.
self.cluster_id = 0
self.config = kv_transfer_config
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
kv_role = kv_transfer_config.kv_role
if kv_role == "kv_producer":
self.role = llm_datadist.LLMRole.PROMPT
elif kv_role == "kv_consumer":
self.role = llm_datadist.LLMRole.DECODER
else:
raise NotImplementedError(
"kv_role should be inside [kv_producer, kv_consumer]")
prompt_device_ips = kv_connector_extra_config.get(
"prompt_device_ips", None)
decode_device_ips = kv_connector_extra_config.get(
"decode_device_ips", None)
if prompt_device_ips is None or decode_device_ips is None:
raise ValueError(
"Please specify prompt_device_ips and decode_device_ips"
"in kv_transfer_config.kv_connector_extra_config")
p_device_num = len(prompt_device_ips)
d_device_num = len(decode_device_ips)
# When number of devices in P and D is not equal,
# we assume that device in D can be mapped to any device in P.
self.p_device_rank = self.rank % p_device_num
self.d_device_rank = self.rank % d_device_num
self.prompt_ip_list = prompt_device_ips
self.decode_ip_list = decode_device_ips
self.llmdatadist_comm_port = kv_connector_extra_config.get(
"llmdatadist_comm_port", 26000)
# LLMDataDist initializing.
self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)
self._prepare_data_dist()
# Decoder needs to initialize and link cluster
if self.role == llm_datadist.LLMRole.DECODER:
self.cluster = self._make_cluster()
_, ret = self.data_dist.link_clusters([self.cluster], 20000)
logger.info(
f"rank {self.rank}, local_rank {self.local_rank} link, ret={ret}"
)
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self._register_thread = None
if port_offset == 0 and self.proxy_address != "":
# Initialize zmq socket and register to proxy.
# Note that only NPU 0 of each P/D instance register to proxy.
if not hostname:
hostname = get_ip() # Get ip of current host.
port = kv_transfer_config.kv_port + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._port = port
# Each card corresponds to a ZMQ address.
self.zmq_address = f"{self._hostname}:{self._port}"
self.context = zmq.Context() # type: ignore
self.router_socket = self.context.socket(
zmq.ROUTER) # type: ignore
self.router_socket.bind(f"tcp://{self.zmq_address}")
# The `http_port` must be consistent with the serving port of OpenAI.
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self._register_thread = threading.Thread(
target=self._register_to_proxy, daemon=True)
self._register_thread.start()
def _prepare_data_dist(self):
options = {
"llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
}
if self.role == llm_datadist.LLMRole.PROMPT:
options["ge.exec.deviceId"] = str(self.local_rank)
options["llm.listenIpInfo"] = (
f"{self.prompt_ip_list[self.p_device_rank]}:{self.llmdatadist_comm_port}"
)
else:
options["ge.exec.deviceId"] = str(self.local_rank)
print(f"prepare datadist, options: {options}")
self.data_dist.init(options)
self.kv_transfer = self.data_dist.kv_cache_manager
print(f"{self.rank} rank data dist is ready")
def _make_cluster(self):
cluster = llm_datadist.LLMClusterInfo()
cluster.remote_cluster_id = self.cluster_id
local_ip = self.decode_ip_list[self.d_device_rank]
remote_ip = self.prompt_ip_list[self.p_device_rank]
cluster.append_local_ip_info(local_ip, 0)
cluster.append_remote_ip_info(remote_ip, self.llmdatadist_comm_port)
return cluster
def _register_to_proxy(self):
sock = self.context.socket(zmq.DEALER) # type: ignore
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) # type: ignore
logger.debug("ping start, zmq_address:%s", self.zmq_address)
sock.connect(f"tcp://{self.proxy_address}")
data = {
"type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"zmq_address": self.zmq_address,
}
while True:
sock.send(msgpack.dumps(data))
time.sleep(3)
def send_tensor(
self,
tensor: Optional[torch.Tensor],
tensor_desc: llm_datadist.CacheDesc,
tensor_key: llm_datadist.CacheKey,
) -> llm_datadist.Cache:
buffer = self.kv_transfer.allocate_cache(tensor_desc, [tensor_key])
buffer_addr = buffer.per_device_tensor_addrs[0]
data_tensor = torchair.llm_datadist.create_npu_tensors(
tensor_desc.shape, tensor.dtype, buffer_addr)[0] # type: ignore
update_indices = torch.tensor(
[0] * tensor.shape[0], # type: ignore
dtype=torch.int64).npu()
torch_npu.scatter_update_(data_tensor, update_indices, tensor, axis=-1)
# Free cache_id of buffer, actual deallocate will happen after consumer performing pull_cache.
self.kv_transfer.deallocate_cache(buffer)
return buffer
def recv_tensor(
self,
tensor_desc: llm_datadist.CacheDesc,
tensor_key: llm_datadist.CacheKey,
) -> llm_datadist.Cache:
"""Note that this function only creates empty tensor on buffer addr and returns it."""
tmp_buffer = self.kv_transfer.allocate_cache(tensor_desc)
buffer_addr = tmp_buffer.per_device_tensor_addrs[0]
data_tensor = torchair.llm_datadist.create_npu_tensors(
tensor_desc.shape,
NPU_DTYPE_TO_TORCH_DTYPE[tensor_desc.data_type],
buffer_addr,
)[0]
self.kv_transfer.pull_cache(tensor_key, tmp_buffer, 0)
# tmp_buffer is allocated without key and will be deallocated here immediately.
# Free buffer here will cause accuracy problem.
# self.kv_transfer.deallocate_cache(tmp_buffer)
return tmp_buffer, data_tensor
def deallocate_buffer(self, buffer: llm_datadist.Cache):
self.kv_transfer.deallocate_cache(buffer)
def close(self):
self.data_dist.unlink_clusters([self.cluster], 5000)

View File

@@ -0,0 +1,40 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import llm_datadist # type: ignore
import torch
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
torch.float16: llm_datadist.DataType.DT_FLOAT16,
torch.bfloat16: llm_datadist.DataType.DT_BF16,
torch.float: llm_datadist.DataType.DT_FLOAT,
torch.float32: llm_datadist.DataType.DT_FLOAT,
torch.int8: llm_datadist.DataType.DT_INT8,
torch.int64: llm_datadist.DataType.DT_INT64,
torch.int32: llm_datadist.DataType.DT_INT32,
}
NPU_DTYPE_TO_TORCH_DTYPE = {
llm_datadist.DataType.DT_FLOAT16: torch.half,
llm_datadist.DataType.DT_FLOAT16: torch.float16,
llm_datadist.DataType.DT_BF16: torch.bfloat16,
llm_datadist.DataType.DT_FLOAT: torch.float,
llm_datadist.DataType.DT_FLOAT: torch.float32,
llm_datadist.DataType.DT_INT8: torch.int8,
llm_datadist.DataType.DT_INT64: torch.int64,
llm_datadist.DataType.DT_INT32: torch.int32,
}

View File

@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import get_dp_group, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -1343,6 +1343,17 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
kv_caches=kv_caches kv_caches=kv_caches
) )
if get_dp_group().world_size > 1:
bypass_model_exec_tensor = torch.tensor(
1, dtype=torch.int32) if bypass_model_exec else torch.tensor(
0, dtype=torch.int32)
torch.distributed.all_reduce(bypass_model_exec_tensor,
op=torch.distributed.ReduceOp.MIN,
group=get_dp_group().cpu_group)
# If there is any group have not receive the necessary hidden states or kv_cache, we force all the dp group execute.
if bypass_model_exec_tensor.item() == 0:
bypass_model_exec = False
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = { seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids, "finished_requests_ids": model_input.finished_requests_ids,
@@ -1399,10 +1410,21 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
torch.tensor(model_forward_time + torch.tensor(model_forward_time +
orig_model_forward_time)) orig_model_forward_time))
return hidden_or_intermediate_states return hidden_or_intermediate_states
# TODO: remove the synchronize here
torch.npu.synchronize() logits = self.model.compute_logits(hidden_or_intermediate_states,
logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata)
model_input.sampling_metadata)
# Sending KV cache in distributed KV cache transfer setting
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
if not self.is_driver_worker: if not self.is_driver_worker:
return [] return []

View File

@@ -18,10 +18,13 @@
# #
import gc import gc
import os
from typing import Dict, List, Optional, Set, Tuple, Type, Union from typing import Dict, List, Optional, Set, Tuple, Type, Union
import msgpack # type: ignore
import torch import torch
import torch.distributed import torch.distributed
import zmq
from torch import nn from torch import nn
from vllm import envs from vllm import envs
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
@@ -37,7 +40,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta) SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import GiB_bytes, bind_kv_cache from vllm.utils import GiB_bytes, bind_kv_cache, get_ip
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.model_runner_base import ModelRunnerBase
@@ -157,6 +160,33 @@ class NPUWorker(LocalOrDistributedWorkerBase):
else: else:
self.profiler = None self.profiler = None
self.enable_dummy_run = False
if os.getenv("VLLM_DP_PROXY_IP", None):
logger.warning("enable dummy run for the DP")
self.enable_dummy_run = True
# dp_rank = os.environ["VLLM_DP_RANK"]
dp_master_ip = os.environ["VLLM_DP_PROXY_IP"]
dp_proxy_listener_port = os.environ["VLLM_DP_PROXY_PORT"]
dp_proxy_monitor_port = os.environ["VLLM_DP_MONITOR_PORT"]
dp_proxy_listener_addr = f"{dp_master_ip}:{dp_proxy_listener_port}"
self.dp_proxy_monitor_addr = f"{dp_master_ip}:{dp_proxy_monitor_port}"
http_ip = get_ip()
port = os.environ["VLLM_HTTP_PORT"]
self.http_addr = f"{http_ip}:{port}"
context = zmq.Context() # type: ignore
sock = context.socket(zmq.DEALER) # type: ignore
logger.debug("ping dp proxy start, DP_RANK:%s", 0)
# logger.debug("ping dp proxy start, DP_RANK:%s", dp_rank)
sock.connect(f"tcp://{dp_proxy_listener_addr}")
data = {"type": "DP", "http_address": self.http_addr}
for _ in range(10):
sock.send(msgpack.dumps(data))
self.notify_socket = context.socket(zmq.PUSH) # type: ignore
self.notify_socket.connect(f"tcp://{self.dp_proxy_monitor_addr}")
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
NPUPlatform.set_device(self.device) NPUPlatform.set_device(self.device)
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0] free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
@@ -375,6 +405,11 @@ class NPUWorker(LocalOrDistributedWorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None: def execute_worker(self, worker_input: WorkerInput) -> None:
if self.enable_dummy_run:
logger.debug(
f"send notify to the dp proxy: {self.dp_proxy_monitor_addr}")
data = {"info": "notify_step", "http_address": self.http_addr}
self.notify_socket.send(msgpack.dumps(data))
virtual_engine = worker_input.virtual_engine virtual_engine = worker_input.virtual_engine
# Issue cache operations. # Issue cache operations.
if (worker_input.blocks_to_swap_in is not None if (worker_input.blocks_to_swap_in is not None