[P/D] Using the cache load operator to replace the index select operator. (#6295)

### What this PR does / why we need it?
Using the cache load operator to replace the index select operator.

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
liziyu
2026-01-30 14:27:53 +08:00
committed by GitHub
parent 70cc5f7969
commit d252e4f5ec
2 changed files with 117 additions and 72 deletions

View File

@@ -735,7 +735,8 @@ class TestHelperFunctions(unittest.TestCase):
)
def test_ensure_zmq_send_success(self, _):
mock_socket = MagicMock()
ensure_zmq_send(mock_socket, b"hello")
path = "127.0.0.1:12345"
ensure_zmq_send(mock_socket, b"hello", path)
mock_socket.send.assert_called_once_with(b"hello")
@patch(
@@ -743,10 +744,11 @@ class TestHelperFunctions(unittest.TestCase):
)
def test_ensure_zmq_send_retry_and_fail(self, _):
mock_socket = MagicMock()
path = "127.0.0.1:12345"
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
"send failed")
with self.assertRaises(RuntimeError):
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
ensure_zmq_send(mock_socket, b"hello", path, max_retries=2)
self.assertEqual(mock_socket.send.call_count, 2)
@patch(
@@ -759,7 +761,8 @@ class TestHelperFunctions(unittest.TestCase):
mock_poller.poll.return_value = [
(mock_socket, zmq.POLLIN) # type: ignore
]
data = ensure_zmq_recv(mock_socket, mock_poller)
path = "127.0.0.1:12345"
data = ensure_zmq_recv(mock_socket, mock_poller, path)
self.assertEqual(data, b"response")
@patch(
@@ -769,9 +772,11 @@ class TestHelperFunctions(unittest.TestCase):
mock_socket = MagicMock()
mock_poller = MagicMock()
mock_poller.poll.return_value = []
path = "127.0.0.1:12345"
with self.assertRaises(RuntimeError):
ensure_zmq_recv(mock_socket,
mock_poller,
path,
timeout=0.01,
max_retries=2)

View File

@@ -19,6 +19,7 @@ import msgspec
import numpy as np
import numpy.typing as npt
import torch
import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm.config import VllmConfig
@@ -79,7 +80,13 @@ class SendTask:
k_cache: torch.Tensor | None = None
v_cache: torch.Tensor | None = None
layer_idx: int = 0
# trans block info
rearrange_block_ids: list[int] | None = None
num_blocks: int | None = None
num_tokens: int | None = None
block_table: torch.Tensor | None = None
block_len_tensor: torch.Tensor | None = None
seq_start_tensor: torch.Tensor | None = None
@dataclass
@@ -419,6 +426,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
self.send_task: SendTask = SendTask()
def add_new_req(
self,
@@ -814,6 +822,7 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip()
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
self.use_mla = self.vllm_config.model_config.use_mla
# Handshake base port
self.side_channel_port = (
@@ -1059,6 +1068,43 @@ class MooncakeLayerwiseConnectorWorker:
with self.kv_recv_layer_thread.lock:
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
elif self.vllm_config.kv_transfer_config.is_kv_producer:
# select req to send
if self.use_mla or self.use_sparse:
num_need_send = self._decode_tp_size
else:
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
if self.tp_size <= num_kv_head:
num_need_send = self.tp_size
else:
num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
replica_group_idx = self.tp_rank % num_replica_groups
req_ids = sorted(list(metadata.requests.keys()))
selected_req_ids = [
req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
]
request_ids = list(metadata.requests.keys())
for req_id in request_ids:
if req_id not in selected_req_ids:
metadata.requests.pop(req_id)
# update send task trans block info
if self.pd_head_ratio != 1:
send_task = metadata.send_task
send_task.rearrange_block_ids = sorted(
{block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids}
)
device = self.k_buffer.device # type: ignore
flat_block_ids = send_task.rearrange_block_ids
block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32, device=device)
send_task.num_blocks = len(flat_block_ids)
send_task.num_tokens = send_task.num_blocks * self.block_size
send_task.block_table = block_ids_tensor.view(1, -1)
send_task.block_len_tensor = torch.tensor([send_task.num_tokens], dtype=torch.int32, device=device)
send_task.seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device)
def save_kv_layer(
self,
@@ -1072,75 +1118,66 @@ class MooncakeLayerwiseConnectorWorker:
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
# enable decode prefix cache
if self.use_mla or self.use_sparse:
num_need_send = self._decode_tp_size
reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
else:
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
if self.tp_size <= num_kv_head:
num_need_send = self.tp_size
else:
num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
replica_group_idx = self.tp_rank % num_replica_groups
req_ids = sorted(list(connector_metadata.requests.keys()))
selected_req_ids = [
req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
]
if selected_req_ids:
if self.use_mla or self.use_sparse:
reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
else:
reshape_cache_event = attn_metadata.reshape_cache_event
reshape_cache_event = attn_metadata.reshape_cache_event
if self.pd_head_ratio != 1:
assert self.resharding_stream is not None
with npu_stream_switch(self.resharding_stream):
reshape_cache_event.wait()
rearrange_block_ids = sorted(
{
block_id
for req_id in selected_req_ids
for block_id in connector_metadata.requests[req_id].local_block_ids
}
)
send_task = connector_metadata.send_task
if self.pd_head_ratio != 1:
assert self.resharding_stream is not None
with npu_stream_switch(self.resharding_stream):
reshape_cache_event.wait()
dtype = self.k_buffer.dtype # type: ignore
device = self.k_buffer.device # type: ignore
# Initialize buffers
keys = torch.empty((send_task.num_tokens, *kv_layer[0].size()[-2:]), dtype=dtype, device=device)
values = torch.empty((send_task.num_tokens, *kv_layer[1].size()[-2:]), dtype=dtype, device=device)
keys = kv_layer[0][rearrange_block_ids].clone()
values = kv_layer[1][rearrange_block_ids].clone()
# sort kv caches for each block
keys = (
keys.view(keys.size(0), self.pd_head_ratio, -1, *keys.shape[2:])
.transpose(0, 1)
.reshape_as(keys)
)
values = (
values.view(values.size(0), self.pd_head_ratio, -1, *values.shape[2:])
.transpose(0, 1)
.reshape_as(values)
)
# reshard kv cache
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
values = values.reshape(-1, *kv_layer[1].shape[2:])
(keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
else:
keys = None
values = None
rearrange_block_ids = None
# Load cache data into buffers
torch_npu.atb.npu_paged_cache_load(
kv_layer[0],
kv_layer[1],
send_task.block_table,
send_task.block_len_tensor,
seq_starts=send_task.seq_start_tensor,
key=keys,
value=values,
)
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
send_task = SendTask(
wait_event=reshape_cache_event,
k_cache=keys,
v_cache=values,
layer_idx=self.current_layer,
rearrange_block_ids=rearrange_block_ids,
)
for req_id, req_meta in connector_metadata.requests.items():
if req_id in selected_req_ids:
req_meta_update = self.update_decoder_info(req_id, req_meta)
logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
send_task.send_request[req_id] = req_meta_update
# sort kv caches for each block
keys = (
keys.view(send_task.num_blocks, self.pd_head_ratio, -1, *keys.shape[1:])
.transpose(0, 1)
.reshape_as(keys)
)
values = (
values.view(send_task.num_blocks, self.pd_head_ratio, -1, *values.shape[1:])
.transpose(0, 1)
.reshape_as(values)
)
# reshard kv cache
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
values = values.reshape(-1, *kv_layer[1].shape[2:])
(keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
else:
keys = None
values = None
self.kv_send_layer_thread.send_queue.put(send_task)
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
layer_send_task = SendTask(
wait_event=reshape_cache_event,
k_cache=keys,
v_cache=values,
layer_idx=self.current_layer,
rearrange_block_ids=send_task.rearrange_block_ids,
)
for req_id, req_meta in connector_metadata.requests.items():
req_meta_update = self.update_decoder_info(req_id, req_meta)
logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
layer_send_task.send_request[req_id] = req_meta_update
self.kv_send_layer_thread.send_queue.put(layer_send_task)
self.current_layer += 1
def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore
@@ -1182,8 +1219,9 @@ class MooncakeLayerwiseConnectorWorker:
try:
encoded_data = self.encoder.encode((GET_META_MSG, req_id))
sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port)
ensure_zmq_send(sock, encoded_data)
metadata_bytes = ensure_zmq_recv(sock, self.remote_poller)
path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}"
ensure_zmq_send(sock, encoded_data, path)
metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path)
agent_meta = self.decoder.decode(metadata_bytes)
except Exception as e:
logger.error(
@@ -1231,7 +1269,7 @@ class MooncakeLayerwiseConnectorWorker:
msg_encoder = msgspec.msgpack.Encoder()
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data)
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
ack = sock.recv()
if ack != b"ACK":
raise ValueError(f"Unexpected ACK response: {ack}")
@@ -1309,6 +1347,7 @@ def string_to_int64_hash(input_str):
def ensure_zmq_send(
socket: zmq.Socket, # type: ignore
data: bytes,
path: str,
max_retries: int = 3,
):
retries_left = max_retries
@@ -1323,12 +1362,13 @@ def ensure_zmq_send(
time.sleep(0.1)
else:
logger.error(f"Send failed after all retries: {e}")
raise RuntimeError(f"Failed to send data after {max_retries} retries: {e}")
raise RuntimeError(f"Failed to send data to {path} after {max_retries} retries: {e}")
def ensure_zmq_recv(
socket: zmq.Socket, # type: ignore
poller: zmq.Poller, # type: ignore
path: str,
timeout: float = 1.0,
max_retries: int = 3,
) -> bytes:
@@ -1347,7 +1387,7 @@ def ensure_zmq_recv(
time.sleep(0.1)
else:
logger.error(f"Receive failed after all retries: {e}")
raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}")
raise RuntimeError(f"Failed to receive data from {path} after {max_retries} retries: {e}")
def get_external_request_id(request_id: str):