[P/D] Using the cache load operator to replace the index select operator. (#6295)
### What this PR does / why we need it?
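Use the paged cache load operator (`torch_npu.atb.npu_paged_cache_load`) in place of the index-select operation (`kv_layer[i][rearrange_block_ids].clone()`) when gathering KV cache blocks in `save_kv_layer`. The request selection and the block table / block length / sequence start tensors that the operator needs are now prepared on `metadata.send_task` rather than inside `save_kv_layer`. In addition, `ensure_zmq_send` and `ensure_zmq_recv` take a `path` (`host:port`) argument that is included in their failure messages.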
- vLLM version: v0.14.1
- vLLM main: dc917cceb8
---------
Signed-off-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
@@ -735,7 +735,8 @@ class TestHelperFunctions(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
def test_ensure_zmq_send_success(self, _):
|
def test_ensure_zmq_send_success(self, _):
|
||||||
mock_socket = MagicMock()
|
mock_socket = MagicMock()
|
||||||
ensure_zmq_send(mock_socket, b"hello")
|
path = "127.0.0.1:12345"
|
||||||
|
ensure_zmq_send(mock_socket, b"hello", path)
|
||||||
mock_socket.send.assert_called_once_with(b"hello")
|
mock_socket.send.assert_called_once_with(b"hello")
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@@ -743,10 +744,11 @@ class TestHelperFunctions(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
def test_ensure_zmq_send_retry_and_fail(self, _):
|
def test_ensure_zmq_send_retry_and_fail(self, _):
|
||||||
mock_socket = MagicMock()
|
mock_socket = MagicMock()
|
||||||
|
path = "127.0.0.1:12345"
|
||||||
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
|
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
|
||||||
"send failed")
|
"send failed")
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
|
ensure_zmq_send(mock_socket, b"hello", path, max_retries=2)
|
||||||
self.assertEqual(mock_socket.send.call_count, 2)
|
self.assertEqual(mock_socket.send.call_count, 2)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@@ -759,7 +761,8 @@ class TestHelperFunctions(unittest.TestCase):
|
|||||||
mock_poller.poll.return_value = [
|
mock_poller.poll.return_value = [
|
||||||
(mock_socket, zmq.POLLIN) # type: ignore
|
(mock_socket, zmq.POLLIN) # type: ignore
|
||||||
]
|
]
|
||||||
data = ensure_zmq_recv(mock_socket, mock_poller)
|
path = "127.0.0.1:12345"
|
||||||
|
data = ensure_zmq_recv(mock_socket, mock_poller, path)
|
||||||
self.assertEqual(data, b"response")
|
self.assertEqual(data, b"response")
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@@ -769,9 +772,11 @@ class TestHelperFunctions(unittest.TestCase):
|
|||||||
mock_socket = MagicMock()
|
mock_socket = MagicMock()
|
||||||
mock_poller = MagicMock()
|
mock_poller = MagicMock()
|
||||||
mock_poller.poll.return_value = []
|
mock_poller.poll.return_value = []
|
||||||
|
path = "127.0.0.1:12345"
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
ensure_zmq_recv(mock_socket,
|
ensure_zmq_recv(mock_socket,
|
||||||
mock_poller,
|
mock_poller,
|
||||||
|
path,
|
||||||
timeout=0.01,
|
timeout=0.01,
|
||||||
max_retries=2)
|
max_retries=2)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import msgspec
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
|
import torch_npu
|
||||||
import zmq
|
import zmq
|
||||||
from mooncake.engine import TransferEngine # type: ignore
|
from mooncake.engine import TransferEngine # type: ignore
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@@ -79,7 +80,13 @@ class SendTask:
|
|||||||
k_cache: torch.Tensor | None = None
|
k_cache: torch.Tensor | None = None
|
||||||
v_cache: torch.Tensor | None = None
|
v_cache: torch.Tensor | None = None
|
||||||
layer_idx: int = 0
|
layer_idx: int = 0
|
||||||
|
# trans block info
|
||||||
rearrange_block_ids: list[int] | None = None
|
rearrange_block_ids: list[int] | None = None
|
||||||
|
num_blocks: int | None = None
|
||||||
|
num_tokens: int | None = None
|
||||||
|
block_table: torch.Tensor | None = None
|
||||||
|
block_len_tensor: torch.Tensor | None = None
|
||||||
|
seq_start_tensor: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -419,6 +426,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
|||||||
class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
|
class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.requests: dict[str, ReqMeta] = {}
|
self.requests: dict[str, ReqMeta] = {}
|
||||||
|
self.send_task: SendTask = SendTask()
|
||||||
|
|
||||||
def add_new_req(
|
def add_new_req(
|
||||||
self,
|
self,
|
||||||
@@ -814,6 +822,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||||
self.side_channel_host = get_ip()
|
self.side_channel_host = get_ip()
|
||||||
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
|
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
|
||||||
|
self.use_mla = self.vllm_config.model_config.use_mla
|
||||||
|
|
||||||
# Handshake base port
|
# Handshake base port
|
||||||
self.side_channel_port = (
|
self.side_channel_port = (
|
||||||
@@ -1059,6 +1068,43 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
with self.kv_recv_layer_thread.lock:
|
with self.kv_recv_layer_thread.lock:
|
||||||
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
|
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
|
||||||
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
|
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
|
||||||
|
elif self.vllm_config.kv_transfer_config.is_kv_producer:
|
||||||
|
# select req to send
|
||||||
|
if self.use_mla or self.use_sparse:
|
||||||
|
num_need_send = self._decode_tp_size
|
||||||
|
else:
|
||||||
|
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
|
||||||
|
if self.tp_size <= num_kv_head:
|
||||||
|
num_need_send = self.tp_size
|
||||||
|
else:
|
||||||
|
num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
|
||||||
|
num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
|
||||||
|
replica_group_idx = self.tp_rank % num_replica_groups
|
||||||
|
req_ids = sorted(list(metadata.requests.keys()))
|
||||||
|
selected_req_ids = [
|
||||||
|
req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
|
||||||
|
]
|
||||||
|
request_ids = list(metadata.requests.keys())
|
||||||
|
for req_id in request_ids:
|
||||||
|
if req_id not in selected_req_ids:
|
||||||
|
metadata.requests.pop(req_id)
|
||||||
|
|
||||||
|
# update send task trans block info
|
||||||
|
if self.pd_head_ratio != 1:
|
||||||
|
send_task = metadata.send_task
|
||||||
|
send_task.rearrange_block_ids = sorted(
|
||||||
|
{block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids}
|
||||||
|
)
|
||||||
|
|
||||||
|
device = self.k_buffer.device # type: ignore
|
||||||
|
flat_block_ids = send_task.rearrange_block_ids
|
||||||
|
block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32, device=device)
|
||||||
|
send_task.num_blocks = len(flat_block_ids)
|
||||||
|
send_task.num_tokens = send_task.num_blocks * self.block_size
|
||||||
|
|
||||||
|
send_task.block_table = block_ids_tensor.view(1, -1)
|
||||||
|
send_task.block_len_tensor = torch.tensor([send_task.num_tokens], dtype=torch.int32, device=device)
|
||||||
|
send_task.seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device)
|
||||||
|
|
||||||
def save_kv_layer(
|
def save_kv_layer(
|
||||||
self,
|
self,
|
||||||
@@ -1071,48 +1117,41 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
"""MooncakeLayerwiseConnector does not save explicitly."""
|
"""MooncakeLayerwiseConnector does not save explicitly."""
|
||||||
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
|
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
|
||||||
# enable decode prefix cache
|
# enable decode prefix cache
|
||||||
if self.use_mla or self.use_sparse:
|
|
||||||
num_need_send = self._decode_tp_size
|
|
||||||
else:
|
|
||||||
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
|
|
||||||
if self.tp_size <= num_kv_head:
|
|
||||||
num_need_send = self.tp_size
|
|
||||||
else:
|
|
||||||
num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
|
|
||||||
num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
|
|
||||||
replica_group_idx = self.tp_rank % num_replica_groups
|
|
||||||
req_ids = sorted(list(connector_metadata.requests.keys()))
|
|
||||||
selected_req_ids = [
|
|
||||||
req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
|
|
||||||
]
|
|
||||||
if selected_req_ids:
|
|
||||||
if self.use_mla or self.use_sparse:
|
if self.use_mla or self.use_sparse:
|
||||||
reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
|
reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
|
||||||
else:
|
else:
|
||||||
reshape_cache_event = attn_metadata.reshape_cache_event
|
reshape_cache_event = attn_metadata.reshape_cache_event
|
||||||
|
|
||||||
|
send_task = connector_metadata.send_task
|
||||||
if self.pd_head_ratio != 1:
|
if self.pd_head_ratio != 1:
|
||||||
assert self.resharding_stream is not None
|
assert self.resharding_stream is not None
|
||||||
with npu_stream_switch(self.resharding_stream):
|
with npu_stream_switch(self.resharding_stream):
|
||||||
reshape_cache_event.wait()
|
reshape_cache_event.wait()
|
||||||
rearrange_block_ids = sorted(
|
dtype = self.k_buffer.dtype # type: ignore
|
||||||
{
|
device = self.k_buffer.device # type: ignore
|
||||||
block_id
|
# Initialize buffers
|
||||||
for req_id in selected_req_ids
|
keys = torch.empty((send_task.num_tokens, *kv_layer[0].size()[-2:]), dtype=dtype, device=device)
|
||||||
for block_id in connector_metadata.requests[req_id].local_block_ids
|
values = torch.empty((send_task.num_tokens, *kv_layer[1].size()[-2:]), dtype=dtype, device=device)
|
||||||
}
|
|
||||||
|
# Load cache data into buffers
|
||||||
|
torch_npu.atb.npu_paged_cache_load(
|
||||||
|
kv_layer[0],
|
||||||
|
kv_layer[1],
|
||||||
|
send_task.block_table,
|
||||||
|
send_task.block_len_tensor,
|
||||||
|
seq_starts=send_task.seq_start_tensor,
|
||||||
|
key=keys,
|
||||||
|
value=values,
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = kv_layer[0][rearrange_block_ids].clone()
|
|
||||||
values = kv_layer[1][rearrange_block_ids].clone()
|
|
||||||
# sort kv caches for each block
|
# sort kv caches for each block
|
||||||
keys = (
|
keys = (
|
||||||
keys.view(keys.size(0), self.pd_head_ratio, -1, *keys.shape[2:])
|
keys.view(send_task.num_blocks, self.pd_head_ratio, -1, *keys.shape[1:])
|
||||||
.transpose(0, 1)
|
.transpose(0, 1)
|
||||||
.reshape_as(keys)
|
.reshape_as(keys)
|
||||||
)
|
)
|
||||||
values = (
|
values = (
|
||||||
values.view(values.size(0), self.pd_head_ratio, -1, *values.shape[2:])
|
values.view(send_task.num_blocks, self.pd_head_ratio, -1, *values.shape[1:])
|
||||||
.transpose(0, 1)
|
.transpose(0, 1)
|
||||||
.reshape_as(values)
|
.reshape_as(values)
|
||||||
)
|
)
|
||||||
@@ -1123,24 +1162,22 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
else:
|
else:
|
||||||
keys = None
|
keys = None
|
||||||
values = None
|
values = None
|
||||||
rearrange_block_ids = None
|
|
||||||
|
|
||||||
assert self.kv_send_layer_thread is not None
|
assert self.kv_send_layer_thread is not None
|
||||||
assert reshape_cache_event is not None
|
assert reshape_cache_event is not None
|
||||||
send_task = SendTask(
|
layer_send_task = SendTask(
|
||||||
wait_event=reshape_cache_event,
|
wait_event=reshape_cache_event,
|
||||||
k_cache=keys,
|
k_cache=keys,
|
||||||
v_cache=values,
|
v_cache=values,
|
||||||
layer_idx=self.current_layer,
|
layer_idx=self.current_layer,
|
||||||
rearrange_block_ids=rearrange_block_ids,
|
rearrange_block_ids=send_task.rearrange_block_ids,
|
||||||
)
|
)
|
||||||
for req_id, req_meta in connector_metadata.requests.items():
|
for req_id, req_meta in connector_metadata.requests.items():
|
||||||
if req_id in selected_req_ids:
|
|
||||||
req_meta_update = self.update_decoder_info(req_id, req_meta)
|
req_meta_update = self.update_decoder_info(req_id, req_meta)
|
||||||
logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
|
logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
|
||||||
send_task.send_request[req_id] = req_meta_update
|
layer_send_task.send_request[req_id] = req_meta_update
|
||||||
|
|
||||||
self.kv_send_layer_thread.send_queue.put(send_task)
|
self.kv_send_layer_thread.send_queue.put(layer_send_task)
|
||||||
self.current_layer += 1
|
self.current_layer += 1
|
||||||
|
|
||||||
def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore
|
def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore
|
||||||
@@ -1182,8 +1219,9 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
try:
|
try:
|
||||||
encoded_data = self.encoder.encode((GET_META_MSG, req_id))
|
encoded_data = self.encoder.encode((GET_META_MSG, req_id))
|
||||||
sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port)
|
sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port)
|
||||||
ensure_zmq_send(sock, encoded_data)
|
path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}"
|
||||||
metadata_bytes = ensure_zmq_recv(sock, self.remote_poller)
|
ensure_zmq_send(sock, encoded_data, path)
|
||||||
|
metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path)
|
||||||
agent_meta = self.decoder.decode(metadata_bytes)
|
agent_meta = self.decoder.decode(metadata_bytes)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -1231,7 +1269,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
msg_encoder = msgspec.msgpack.Encoder()
|
msg_encoder = msgspec.msgpack.Encoder()
|
||||||
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
|
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
|
||||||
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
||||||
ensure_zmq_send(sock, encoded_data)
|
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
|
||||||
ack = sock.recv()
|
ack = sock.recv()
|
||||||
if ack != b"ACK":
|
if ack != b"ACK":
|
||||||
raise ValueError(f"Unexpected ACK response: {ack}")
|
raise ValueError(f"Unexpected ACK response: {ack}")
|
||||||
@@ -1309,6 +1347,7 @@ def string_to_int64_hash(input_str):
|
|||||||
def ensure_zmq_send(
|
def ensure_zmq_send(
|
||||||
socket: zmq.Socket, # type: ignore
|
socket: zmq.Socket, # type: ignore
|
||||||
data: bytes,
|
data: bytes,
|
||||||
|
path: str,
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
):
|
):
|
||||||
retries_left = max_retries
|
retries_left = max_retries
|
||||||
@@ -1323,12 +1362,13 @@ def ensure_zmq_send(
|
|||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Send failed after all retries: {e}")
|
logger.error(f"Send failed after all retries: {e}")
|
||||||
raise RuntimeError(f"Failed to send data after {max_retries} retries: {e}")
|
raise RuntimeError(f"Failed to send data to {path} after {max_retries} retries: {e}")
|
||||||
|
|
||||||
|
|
||||||
def ensure_zmq_recv(
|
def ensure_zmq_recv(
|
||||||
socket: zmq.Socket, # type: ignore
|
socket: zmq.Socket, # type: ignore
|
||||||
poller: zmq.Poller, # type: ignore
|
poller: zmq.Poller, # type: ignore
|
||||||
|
path: str,
|
||||||
timeout: float = 1.0,
|
timeout: float = 1.0,
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
@@ -1347,7 +1387,7 @@ def ensure_zmq_recv(
|
|||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Receive failed after all retries: {e}")
|
logger.error(f"Receive failed after all retries: {e}")
|
||||||
raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}")
|
raise RuntimeError(f"Failed to receive data from {path} after {max_retries} retries: {e}")
|
||||||
|
|
||||||
|
|
||||||
def get_external_request_id(request_id: str):
|
def get_external_request_id(request_id: str):
|
||||||
|
|||||||
Reference in New Issue
Block a user