[P/D] Using the cache load operator to replace the index select operator. (#6295)
### What this PR does / why we need it?
Replace the index select operator with the cache load operator (`torch_npu.atb.npu_paged_cache_load`) when gathering KV cache blocks for layerwise transfer on the producer side.
- vLLM version: v0.14.1
- vLLM main: dc917cceb8
---------
Signed-off-by: liziyu <liziyu16@huawei.com>
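For context, a rough sketch of the change on the prefill (producer) side: the old path gathered the selected KV blocks with a plain index select (advanced indexing plus `clone()`), while the new path preallocates flat token buffers and fills them with the ATB paged cache load operator. The tensor names, shapes, and the `"npu"` device below are illustrative only; the call signature mirrors the one used in this diff.

```python
import torch
import torch_npu  # Ascend extension providing the torch_npu.atb operators

# Illustrative paged KV caches: [num_blocks, block_size, num_kv_heads, head_dim]
k_cache = torch.empty(1024, 128, 8, 128, dtype=torch.bfloat16, device="npu")
v_cache = torch.empty_like(k_cache)
block_ids = [3, 17, 42]  # blocks owned by the requests selected for transfer

# Old approach: one advanced-indexing gather (index select) per cache, plus a copy.
keys_old = k_cache[block_ids].clone()    # [3, 128, 8, 128]
values_old = v_cache[block_ids].clone()

# New approach: preallocate flat [num_tokens, heads, head_dim] buffers and let the
# cache load operator gather the tokens of the listed blocks into them.
num_tokens = len(block_ids) * k_cache.size(1)
keys = torch.empty((num_tokens, *k_cache.shape[-2:]), dtype=k_cache.dtype, device=k_cache.device)
values = torch.empty_like(keys)
block_table = torch.tensor(block_ids, dtype=torch.int32, device=k_cache.device).view(1, -1)
block_lens = torch.tensor([num_tokens], dtype=torch.int32, device=k_cache.device)
seq_starts = torch.tensor([0], dtype=torch.int32, device=k_cache.device)

torch_npu.atb.npu_paged_cache_load(
    k_cache,
    v_cache,
    block_table,
    block_lens,
    seq_starts=seq_starts,
    key=keys,
    value=values,
)
```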
@@ -735,7 +735,8 @@ class TestHelperFunctions(unittest.TestCase):
    )
    def test_ensure_zmq_send_success(self, _):
        mock_socket = MagicMock()
        ensure_zmq_send(mock_socket, b"hello")
        path = "127.0.0.1:12345"
        ensure_zmq_send(mock_socket, b"hello", path)
        mock_socket.send.assert_called_once_with(b"hello")

    @patch(
@@ -743,10 +744,11 @@ class TestHelperFunctions(unittest.TestCase):
    )
    def test_ensure_zmq_send_retry_and_fail(self, _):
        mock_socket = MagicMock()
        path = "127.0.0.1:12345"
        mock_socket.send.side_effect = zmq.ZMQError(  # type: ignore
            "send failed")
        with self.assertRaises(RuntimeError):
            ensure_zmq_send(mock_socket, b"hello", max_retries=2)
            ensure_zmq_send(mock_socket, b"hello", path, max_retries=2)
        self.assertEqual(mock_socket.send.call_count, 2)

    @patch(
@@ -759,7 +761,8 @@ class TestHelperFunctions(unittest.TestCase):
        mock_poller.poll.return_value = [
            (mock_socket, zmq.POLLIN)  # type: ignore
        ]
        data = ensure_zmq_recv(mock_socket, mock_poller)
        path = "127.0.0.1:12345"
        data = ensure_zmq_recv(mock_socket, mock_poller, path)
        self.assertEqual(data, b"response")

    @patch(
@@ -769,9 +772,11 @@ class TestHelperFunctions(unittest.TestCase):
        mock_socket = MagicMock()
        mock_poller = MagicMock()
        mock_poller.poll.return_value = []
        path = "127.0.0.1:12345"
        with self.assertRaises(RuntimeError):
            ensure_zmq_recv(mock_socket,
                            mock_poller,
                            path,
                            timeout=0.01,
                            max_retries=2)
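The four tests above exercise the new `path` argument of the ZMQ helpers. Based on the signature and error-message changes further down in this diff, the send helper now looks roughly like the sketch below; the retry-loop body is reconstructed and only indicative.

```python
import logging
import time

import zmq

logger = logging.getLogger(__name__)


def ensure_zmq_send(
    socket: zmq.Socket,
    data: bytes,
    path: str,
    max_retries: int = 3,
):
    """Send `data`, retrying on ZMQError; `path` names the peer in error messages."""
    retries_left = max_retries
    while True:
        try:
            socket.send(data)
            return
        except zmq.ZMQError as e:
            retries_left -= 1
            if retries_left > 0:
                logger.warning(f"Send to {path} failed: {e}, retrying...")
                time.sleep(0.1)
            else:
                logger.error(f"Send failed after all retries: {e}")
                raise RuntimeError(
                    f"Failed to send data to {path} after {max_retries} retries: {e}")
```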
@@ -19,6 +19,7 @@ import msgspec
import numpy as np
import numpy.typing as npt
import torch
import torch_npu
import zmq
from mooncake.engine import TransferEngine  # type: ignore
from vllm.config import VllmConfig
@@ -79,7 +80,13 @@ class SendTask:
    k_cache: torch.Tensor | None = None
    v_cache: torch.Tensor | None = None
    layer_idx: int = 0
    # trans block info
    rearrange_block_ids: list[int] | None = None
    num_blocks: int | None = None
    num_tokens: int | None = None
    block_table: torch.Tensor | None = None
    block_len_tensor: torch.Tensor | None = None
    seq_start_tensor: torch.Tensor | None = None


@dataclass
@@ -419,6 +426,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
    def __init__(self):
        self.requests: dict[str, ReqMeta] = {}
        self.send_task: SendTask = SendTask()

    def add_new_req(
        self,
@@ -814,6 +822,7 @@ class MooncakeLayerwiseConnectorWorker:
        self.kv_caches: dict[str, torch.Tensor] = {}
        self.side_channel_host = get_ip()
        self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
        self.use_mla = self.vllm_config.model_config.use_mla

        # Handshake base port
        self.side_channel_port = (
@@ -1059,6 +1068,43 @@ class MooncakeLayerwiseConnectorWorker:
            with self.kv_recv_layer_thread.lock:
                self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
                self.kv_recv_layer_thread.request_map[external_req_id] = req_id
        elif self.vllm_config.kv_transfer_config.is_kv_producer:
            # select req to send
            if self.use_mla or self.use_sparse:
                num_need_send = self._decode_tp_size
            else:
                num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
                if self.tp_size <= num_kv_head:
                    num_need_send = self.tp_size
                else:
                    num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
            num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
            replica_group_idx = self.tp_rank % num_replica_groups
            req_ids = sorted(list(metadata.requests.keys()))
            selected_req_ids = [
                req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
            ]
            request_ids = list(metadata.requests.keys())
            for req_id in request_ids:
                if req_id not in selected_req_ids:
                    metadata.requests.pop(req_id)
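To make the selection above concrete, here is a small standalone trace with hypothetical sizes (8 prefill TP ranks, 2 KV heads, decode TP size 4, rank 3); only the arithmetic from the block above is reproduced.

```python
tp_size = 8          # prefill tensor-parallel size
decode_tp_size = 4   # self._decode_tp_size
num_kv_head = 2      # hf_config.num_key_value_heads
tp_rank = 3

if tp_size <= num_kv_head:
    num_need_send = tp_size
else:
    num_need_send = decode_tp_size if decode_tp_size >= num_kv_head else num_kv_head  # -> 4

num_replica_groups = tp_size // num_need_send if tp_size >= num_need_send else 1      # -> 2
replica_group_idx = tp_rank % num_replica_groups                                      # -> 1

req_ids = sorted(["req-a", "req-b", "req-c", "req-d"])
selected_req_ids = [
    r for i, r in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
]
print(selected_req_ids)  # ['req-b', 'req-d']: this rank sends every second request
```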

            # update send task trans block info
            if self.pd_head_ratio != 1:
                send_task = metadata.send_task
                send_task.rearrange_block_ids = sorted(
                    {block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids}
                )

                device = self.k_buffer.device  # type: ignore
                flat_block_ids = send_task.rearrange_block_ids
                block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32, device=device)
                send_task.num_blocks = len(flat_block_ids)
                send_task.num_tokens = send_task.num_blocks * self.block_size

                send_task.block_table = block_ids_tensor.view(1, -1)
                send_task.block_len_tensor = torch.tensor([send_task.num_tokens], dtype=torch.int32, device=device)
                send_task.seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device)
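Continuing with hypothetical numbers, the trans-block-info tensors built above end up with shapes like the following (block_size assumed to be 128, device omitted); these are the `block_table`, `block_len_tensor`, and `seq_start_tensor` consumed by the cache load call in `save_kv_layer` below.

```python
import torch

block_size = 128                      # assumed self.block_size
rearrange_block_ids = [3, 17, 42]     # union of the selected requests' local block ids

block_ids_tensor = torch.tensor(rearrange_block_ids, dtype=torch.int32)
num_blocks = len(rearrange_block_ids)            # 3
num_tokens = num_blocks * block_size             # 384

block_table = block_ids_tensor.view(1, -1)       # shape [1, 3]
block_len_tensor = torch.tensor([num_tokens], dtype=torch.int32)  # tensor([384])
seq_start_tensor = torch.tensor([0], dtype=torch.int32)           # tensor([0])
```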

    def save_kv_layer(
        self,
@@ -1072,75 +1118,66 @@ class MooncakeLayerwiseConnectorWorker:
        if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
            # enable decode prefix cache
            if self.use_mla or self.use_sparse:
                num_need_send = self._decode_tp_size
                reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
            else:
                num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
                if self.tp_size <= num_kv_head:
                    num_need_send = self.tp_size
                else:
                    num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
            num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
            replica_group_idx = self.tp_rank % num_replica_groups
            req_ids = sorted(list(connector_metadata.requests.keys()))
            selected_req_ids = [
                req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
            ]
            if selected_req_ids:
                if self.use_mla or self.use_sparse:
                    reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
                else:
                    reshape_cache_event = attn_metadata.reshape_cache_event
                reshape_cache_event = attn_metadata.reshape_cache_event

                if self.pd_head_ratio != 1:
                    assert self.resharding_stream is not None
                    with npu_stream_switch(self.resharding_stream):
                        reshape_cache_event.wait()
                        rearrange_block_ids = sorted(
                            {
                                block_id
                                for req_id in selected_req_ids
                                for block_id in connector_metadata.requests[req_id].local_block_ids
                            }
                        )
            send_task = connector_metadata.send_task
            if self.pd_head_ratio != 1:
                assert self.resharding_stream is not None
                with npu_stream_switch(self.resharding_stream):
                    reshape_cache_event.wait()
                    dtype = self.k_buffer.dtype  # type: ignore
                    device = self.k_buffer.device  # type: ignore
                    # Initialize buffers
                    keys = torch.empty((send_task.num_tokens, *kv_layer[0].size()[-2:]), dtype=dtype, device=device)
                    values = torch.empty((send_task.num_tokens, *kv_layer[1].size()[-2:]), dtype=dtype, device=device)

                        keys = kv_layer[0][rearrange_block_ids].clone()
                        values = kv_layer[1][rearrange_block_ids].clone()
                        # sort kv caches for each block
                        keys = (
                            keys.view(keys.size(0), self.pd_head_ratio, -1, *keys.shape[2:])
                            .transpose(0, 1)
                            .reshape_as(keys)
                        )
                        values = (
                            values.view(values.size(0), self.pd_head_ratio, -1, *values.shape[2:])
                            .transpose(0, 1)
                            .reshape_as(values)
                        )
                        # reshard kv cache
                        keys = keys.reshape(-1, *kv_layer[0].shape[2:])
                        values = values.reshape(-1, *kv_layer[1].shape[2:])
                        (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
                else:
                    keys = None
                    values = None
                    rearrange_block_ids = None
                    # Load cache data into buffers
                    torch_npu.atb.npu_paged_cache_load(
                        kv_layer[0],
                        kv_layer[1],
                        send_task.block_table,
                        send_task.block_len_tensor,
                        seq_starts=send_task.seq_start_tensor,
                        key=keys,
                        value=values,
                    )

                assert self.kv_send_layer_thread is not None
                assert reshape_cache_event is not None
                send_task = SendTask(
                    wait_event=reshape_cache_event,
                    k_cache=keys,
                    v_cache=values,
                    layer_idx=self.current_layer,
                    rearrange_block_ids=rearrange_block_ids,
                )
                for req_id, req_meta in connector_metadata.requests.items():
                    if req_id in selected_req_ids:
                        req_meta_update = self.update_decoder_info(req_id, req_meta)
                        logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
                        send_task.send_request[req_id] = req_meta_update
                    # sort kv caches for each block
                    keys = (
                        keys.view(send_task.num_blocks, self.pd_head_ratio, -1, *keys.shape[1:])
                        .transpose(0, 1)
                        .reshape_as(keys)
                    )
                    values = (
                        values.view(send_task.num_blocks, self.pd_head_ratio, -1, *values.shape[1:])
                        .transpose(0, 1)
                        .reshape_as(values)
                    )
                    # reshard kv cache
                    keys = keys.reshape(-1, *kv_layer[0].shape[2:])
                    values = values.reshape(-1, *kv_layer[1].shape[2:])
                    (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
            else:
                keys = None
                values = None

                self.kv_send_layer_thread.send_queue.put(send_task)
            assert self.kv_send_layer_thread is not None
            assert reshape_cache_event is not None
            layer_send_task = SendTask(
                wait_event=reshape_cache_event,
                k_cache=keys,
                v_cache=values,
                layer_idx=self.current_layer,
                rearrange_block_ids=send_task.rearrange_block_ids,
            )
            for req_id, req_meta in connector_metadata.requests.items():
                req_meta_update = self.update_decoder_info(req_id, req_meta)
                logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
                layer_send_task.send_request[req_id] = req_meta_update

            self.kv_send_layer_thread.send_queue.put(layer_send_task)
        self.current_layer += 1

    def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket:  # type: ignore
@@ -1182,8 +1219,9 @@ class MooncakeLayerwiseConnectorWorker:
        try:
            encoded_data = self.encoder.encode((GET_META_MSG, req_id))
            sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port)
            ensure_zmq_send(sock, encoded_data)
            metadata_bytes = ensure_zmq_recv(sock, self.remote_poller)
            path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}"
            ensure_zmq_send(sock, encoded_data, path)
            metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path)
            agent_meta = self.decoder.decode(metadata_bytes)
        except Exception as e:
            logger.error(
@@ -1231,7 +1269,7 @@ class MooncakeLayerwiseConnectorWorker:
        msg_encoder = msgspec.msgpack.Encoder()
        encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
        with zmq_ctx(zmq.REQ, path) as sock:  # type: ignore
            ensure_zmq_send(sock, encoded_data)
            ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
            ack = sock.recv()
            if ack != b"ACK":
                raise ValueError(f"Unexpected ACK response: {ack}")
@@ -1309,6 +1347,7 @@ def string_to_int64_hash(input_str):
def ensure_zmq_send(
    socket: zmq.Socket,  # type: ignore
    data: bytes,
    path: str,
    max_retries: int = 3,
):
    retries_left = max_retries
@@ -1323,12 +1362,13 @@ def ensure_zmq_send(
                time.sleep(0.1)
            else:
                logger.error(f"Send failed after all retries: {e}")
                raise RuntimeError(f"Failed to send data after {max_retries} retries: {e}")
                raise RuntimeError(f"Failed to send data to {path} after {max_retries} retries: {e}")


def ensure_zmq_recv(
    socket: zmq.Socket,  # type: ignore
    poller: zmq.Poller,  # type: ignore
    path: str,
    timeout: float = 1.0,
    max_retries: int = 3,
) -> bytes:
@@ -1347,7 +1387,7 @@ def ensure_zmq_recv(
                time.sleep(0.1)
            else:
                logger.error(f"Receive failed after all retries: {e}")
                raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}")
                raise RuntimeError(f"Failed to receive data from {path} after {max_retries} retries: {e}")


def get_external_request_id(request_id: str):