From d252e4f5ecd9b80c4bdcbefeb1116f5d6470d0c5 Mon Sep 17 00:00:00 2001
From: liziyu <56102866+liziyu179@users.noreply.github.com>
Date: Fri, 30 Jan 2026 14:27:53 +0800
Subject: [PATCH] [P/D] Using the cache load operator to replace the index
 select operator. (#6295)

### What this PR does / why we need it?
Using the cache load operator to replace the index select operator.

- vLLM version: v0.14.1
- vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd

---------

Signed-off-by: liziyu
---
 .../test_mooncake_layerwise_connector.py      |  11 +-
 .../kv_p2p/mooncake_layerwise_connector.py    | 178 +++++++++++-------
 2 files changed, 117 insertions(+), 72 deletions(-)

diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
index 54275f25..f8266580 100644
--- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
@@ -735,7 +735,8 @@ class TestHelperFunctions(unittest.TestCase):
     )
     def test_ensure_zmq_send_success(self, _):
         mock_socket = MagicMock()
-        ensure_zmq_send(mock_socket, b"hello")
+        path = "127.0.0.1:12345"
+        ensure_zmq_send(mock_socket, b"hello", path)
         mock_socket.send.assert_called_once_with(b"hello")
 
     @patch(
@@ -743,10 +744,11 @@ class TestHelperFunctions(unittest.TestCase):
     )
     def test_ensure_zmq_send_retry_and_fail(self, _):
         mock_socket = MagicMock()
+        path = "127.0.0.1:12345"
         mock_socket.send.side_effect = zmq.ZMQError(  # type: ignore
             "send failed")
         with self.assertRaises(RuntimeError):
-            ensure_zmq_send(mock_socket, b"hello", max_retries=2)
+            ensure_zmq_send(mock_socket, b"hello", path, max_retries=2)
         self.assertEqual(mock_socket.send.call_count, 2)
 
     @patch(
@@ -759,7 +761,8 @@ class TestHelperFunctions(unittest.TestCase):
         mock_poller.poll.return_value = [
             (mock_socket, zmq.POLLIN)  # type: ignore
         ]
-        data = ensure_zmq_recv(mock_socket, mock_poller)
+        path = "127.0.0.1:12345"
+        data = ensure_zmq_recv(mock_socket, mock_poller, path)
         self.assertEqual(data, b"response")
 
     @patch(
@@ -769,9 +772,11 @@ class TestHelperFunctions(unittest.TestCase):
         mock_socket = MagicMock()
         mock_poller = MagicMock()
         mock_poller.poll.return_value = []
+        path = "127.0.0.1:12345"
         with self.assertRaises(RuntimeError):
             ensure_zmq_recv(mock_socket,
                             mock_poller,
+                            path,
                             timeout=0.01,
                             max_retries=2)
 
diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
index 67463b90..84259848 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
@@ -19,6 +19,7 @@ import msgspec
 import numpy as np
 import numpy.typing as npt
 import torch
+import torch_npu
 import zmq
 from mooncake.engine import TransferEngine  # type: ignore
 from vllm.config import VllmConfig
@@ -79,7 +80,13 @@ class SendTask:
     k_cache: torch.Tensor | None = None
     v_cache: torch.Tensor | None = None
     layer_idx: int = 0
+    # trans block info
     rearrange_block_ids: list[int] | None = None
+    num_blocks: int | None = None
+    num_tokens: int | None = None
+    block_table: torch.Tensor | None = None
+    block_len_tensor: torch.Tensor | None = None
+    seq_start_tensor: torch.Tensor | None = None
 
 
 @dataclass
@@ -419,6 +426,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
 class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
     def __init__(self):
         self.requests: dict[str, ReqMeta] = {}
+        self.send_task: SendTask = SendTask()
 
     def add_new_req(
         self,
@@ -814,6 +822,7 @@ class MooncakeLayerwiseConnectorWorker:
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
         self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
+        self.use_mla = self.vllm_config.model_config.use_mla
 
         # Handshake base port
         self.side_channel_port = (
@@ -1059,6 +1068,43 @@ class MooncakeLayerwiseConnectorWorker:
                 with self.kv_recv_layer_thread.lock:
                     self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
                     self.kv_recv_layer_thread.request_map[external_req_id] = req_id
+        elif self.vllm_config.kv_transfer_config.is_kv_producer:
+            # select req to send
+            if self.use_mla or self.use_sparse:
+                num_need_send = self._decode_tp_size
+            else:
+                num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
+                if self.tp_size <= num_kv_head:
+                    num_need_send = self.tp_size
+                else:
+                    num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
+            num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
+            replica_group_idx = self.tp_rank % num_replica_groups
+            req_ids = sorted(list(metadata.requests.keys()))
+            selected_req_ids = [
+                req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
+            ]
+            request_ids = list(metadata.requests.keys())
+            for req_id in request_ids:
+                if req_id not in selected_req_ids:
+                    metadata.requests.pop(req_id)
+
+            # update send task trans block info
+            if self.pd_head_ratio != 1:
+                send_task = metadata.send_task
+                send_task.rearrange_block_ids = sorted(
+                    {block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids}
+                )
+
+                device = self.k_buffer.device  # type: ignore
+                flat_block_ids = send_task.rearrange_block_ids
+                block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32, device=device)
+                send_task.num_blocks = len(flat_block_ids)
+                send_task.num_tokens = send_task.num_blocks * self.block_size
+
+                send_task.block_table = block_ids_tensor.view(1, -1)
+                send_task.block_len_tensor = torch.tensor([send_task.num_tokens], dtype=torch.int32, device=device)
+                send_task.seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device)
 
     def save_kv_layer(
         self,
@@ -1072,75 +1118,66 @@ class MooncakeLayerwiseConnectorWorker:
         if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
             # enable decode prefix cache
             if self.use_mla or self.use_sparse:
-                num_need_send = self._decode_tp_size
+                reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
             else:
-                num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
-                if self.tp_size <= num_kv_head:
-                    num_need_send = self.tp_size
-                else:
-                    num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
-            num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
-            replica_group_idx = self.tp_rank % num_replica_groups
-            req_ids = sorted(list(connector_metadata.requests.keys()))
-            selected_req_ids = [
-                req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
-            ]
-            if selected_req_ids:
-                if self.use_mla or self.use_sparse:
-                    reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
-                else:
-                    reshape_cache_event = attn_metadata.reshape_cache_event
+                reshape_cache_event = attn_metadata.reshape_cache_event
 
-                if self.pd_head_ratio != 1:
-                    assert self.resharding_stream is not None
-                    with npu_stream_switch(self.resharding_stream):
-                        reshape_cache_event.wait()
-                        rearrange_block_ids = sorted(
-                            {
-                                block_id
-                                for req_id in selected_req_ids
-                                for block_id in connector_metadata.requests[req_id].local_block_ids
-                            }
-                        )
+            send_task = connector_metadata.send_task
+            if self.pd_head_ratio != 1:
+                assert self.resharding_stream is not None
+                with npu_stream_switch(self.resharding_stream):
+                    reshape_cache_event.wait()
+                    dtype = self.k_buffer.dtype  # type: ignore
+                    device = self.k_buffer.device  # type: ignore
+                    # Initialize buffers
+                    keys = torch.empty((send_task.num_tokens, *kv_layer[0].size()[-2:]), dtype=dtype, device=device)
+                    values = torch.empty((send_task.num_tokens, *kv_layer[1].size()[-2:]), dtype=dtype, device=device)
 
-                        keys = kv_layer[0][rearrange_block_ids].clone()
-                        values = kv_layer[1][rearrange_block_ids].clone()
-                        # sort kv caches for each block
-                        keys = (
-                            keys.view(keys.size(0), self.pd_head_ratio, -1, *keys.shape[2:])
-                            .transpose(0, 1)
-                            .reshape_as(keys)
-                        )
-                        values = (
-                            values.view(values.size(0), self.pd_head_ratio, -1, *values.shape[2:])
-                            .transpose(0, 1)
-                            .reshape_as(values)
-                        )
-                        # reshard kv cache
-                        keys = keys.reshape(-1, *kv_layer[0].shape[2:])
-                        values = values.reshape(-1, *kv_layer[1].shape[2:])
-                        (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
-                else:
-                    keys = None
-                    values = None
-                    rearrange_block_ids = None
+                    # Load cache data into buffers
+                    torch_npu.atb.npu_paged_cache_load(
+                        kv_layer[0],
+                        kv_layer[1],
+                        send_task.block_table,
+                        send_task.block_len_tensor,
+                        seq_starts=send_task.seq_start_tensor,
+                        key=keys,
+                        value=values,
+                    )
 
-                assert self.kv_send_layer_thread is not None
-                assert reshape_cache_event is not None
-                send_task = SendTask(
-                    wait_event=reshape_cache_event,
-                    k_cache=keys,
-                    v_cache=values,
-                    layer_idx=self.current_layer,
-                    rearrange_block_ids=rearrange_block_ids,
-                )
-                for req_id, req_meta in connector_metadata.requests.items():
-                    if req_id in selected_req_ids:
-                        req_meta_update = self.update_decoder_info(req_id, req_meta)
-                        logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
-                        send_task.send_request[req_id] = req_meta_update
+                    # sort kv caches for each block
+                    keys = (
+                        keys.view(send_task.num_blocks, self.pd_head_ratio, -1, *keys.shape[1:])
+                        .transpose(0, 1)
+                        .reshape_as(keys)
+                    )
+                    values = (
+                        values.view(send_task.num_blocks, self.pd_head_ratio, -1, *values.shape[1:])
+                        .transpose(0, 1)
+                        .reshape_as(values)
+                    )
+                    # reshard kv cache
+                    keys = keys.reshape(-1, *kv_layer[0].shape[2:])
+                    values = values.reshape(-1, *kv_layer[1].shape[2:])
+                    (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
+            else:
+                keys = None
+                values = None
 
-                self.kv_send_layer_thread.send_queue.put(send_task)
+            assert self.kv_send_layer_thread is not None
+            assert reshape_cache_event is not None
+            layer_send_task = SendTask(
+                wait_event=reshape_cache_event,
+                k_cache=keys,
+                v_cache=values,
+                layer_idx=self.current_layer,
+                rearrange_block_ids=send_task.rearrange_block_ids,
+            )
+            for req_id, req_meta in connector_metadata.requests.items():
+                req_meta_update = self.update_decoder_info(req_id, req_meta)
+                logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
{req_meta_update=}") + layer_send_task.send_request[req_id] = req_meta_update + + self.kv_send_layer_thread.send_queue.put(layer_send_task) self.current_layer += 1 def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore @@ -1182,8 +1219,9 @@ class MooncakeLayerwiseConnectorWorker: try: encoded_data = self.encoder.encode((GET_META_MSG, req_id)) sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port) - ensure_zmq_send(sock, encoded_data) - metadata_bytes = ensure_zmq_recv(sock, self.remote_poller) + path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}" + ensure_zmq_send(sock, encoded_data, path) + metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path) agent_meta = self.decoder.decode(metadata_bytes) except Exception as e: logger.error( @@ -1231,7 +1269,7 @@ class MooncakeLayerwiseConnectorWorker: msg_encoder = msgspec.msgpack.Encoder() encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id)) with zmq_ctx(zmq.REQ, path) as sock: # type: ignore - ensure_zmq_send(sock, encoded_data) + ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") ack = sock.recv() if ack != b"ACK": raise ValueError(f"Unexpected ACK response: {ack}") @@ -1309,6 +1347,7 @@ def string_to_int64_hash(input_str): def ensure_zmq_send( socket: zmq.Socket, # type: ignore data: bytes, + path: str, max_retries: int = 3, ): retries_left = max_retries @@ -1323,12 +1362,13 @@ def ensure_zmq_send( time.sleep(0.1) else: logger.error(f"Send failed after all retries: {e}") - raise RuntimeError(f"Failed to send data after {max_retries} retries: {e}") + raise RuntimeError(f"Failed to send data to {path} after {max_retries} retries: {e}") def ensure_zmq_recv( socket: zmq.Socket, # type: ignore poller: zmq.Poller, # type: ignore + path: str, timeout: float = 1.0, max_retries: int = 3, ) -> bytes: @@ -1347,7 +1387,7 @@ def ensure_zmq_recv( time.sleep(0.1) else: logger.error(f"Receive failed after all retries: {e}") - raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}") + raise RuntimeError(f"Failed to receive data from {path} after {max_retries} retries: {e}") def get_external_request_id(request_id: str):