[P/D] Performance enhancement of Layerwise connector in TP asymmetric scenarios (#5540)

### What this PR does / why we need it?
[P/D] Performance enhancement of Layerwise connector in TP asymmetric
scenarios
1. Session fusion: for each layer's transmission, aggregate the transfer tasks that share the same destination session and merge them into a single task before dispatch (see the first sketch below).
2. All-to-all aggregation: in TP asymmetric scenarios, issue a single all-to-all at block granularity that covers all requests, instead of one collective per request (see the second sketch, after the RFC link).
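
A minimal sketch of the session-fusion idea, loosely following the `TransferMeta` grouping this PR introduces; the helper name and task tuple layout below are illustrative, not the connector's exact API:

```python
# Sketch only: group per-request transfer segments by destination session so
# that one batched write replaces one engine call per request.
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class FusedTransfer:  # simplified stand-in for the PR's TransferMeta
    src: list[int] = field(default_factory=list)
    dst: list[int] = field(default_factory=list)
    length: list[int] = field(default_factory=list)
    req_ids: list[str] = field(default_factory=list)


def fuse_by_session(tasks):
    """tasks: iterable of (req_id, session_id, src_list, dst_list, length_list)."""
    sessions: dict[str, FusedTransfer] = defaultdict(FusedTransfer)
    for req_id, session_id, src, dst, length in tasks:
        meta = sessions[session_id]
        meta.src.extend(src)
        meta.dst.extend(dst)
        meta.length.extend(length)
        meta.req_ids.append(req_id)
    # The caller then issues one batch_transfer_sync_write per session.
    return sessions
```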

[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache
Layerwise Push Support
https://github.com/vllm-project/vllm-ascend/issues/4842
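
A rough sketch of the block-granularity all-to-all aggregation, assuming `kv_layer` is a (key_cache, value_cache) pair shaped `[num_blocks, block_size, num_heads, head_dim]` with `block_size` divisible by `pd_head_ratio`; the reshape mirrors the rearrangement added in this PR, while the final collective stands in for the connector's `kv_alltoall_and_rearrange` helper:

```python
# Sketch only: gather every request's blocks once per layer, rearrange so the
# chunks destined for each decode rank are contiguous, then run one collective
# instead of one all-to-all per request.
import torch


def aggregate_for_single_alltoall(kv_layer, per_request_block_ids, pd_head_ratio):
    # Deduplicated, sorted union of all requests' local block ids.
    block_ids = sorted({b for ids in per_request_block_ids for b in ids})
    keys = kv_layer[0][block_ids].clone()
    values = kv_layer[1][block_ids].clone()
    # Split each block into pd_head_ratio chunks and interleave them so each
    # destination rank's data becomes contiguous (assumed layout).
    keys = (keys.view(keys.size(0), pd_head_ratio, -1, *keys.shape[2:])
                .transpose(0, 1).reshape(-1, *kv_layer[0].shape[2:]))
    values = (values.view(values.size(0), pd_head_ratio, -1, *values.shape[2:])
                    .transpose(0, 1).reshape(-1, *kv_layer[1].shape[2:]))
    # One all-to-all over keys/values now covers all requests; in the connector
    # this is kv_alltoall_and_rearrange(pd_head_ratio, keys, values).
    return keys, values, block_ids
```
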
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
liziyu
2026-01-06 20:25:36 +08:00
committed by GitHub
parent cd1162e25a
commit 330e25ab1d
2 changed files with 341 additions and 275 deletions


@@ -1,3 +1,4 @@
import contextlib
import os
import sys
import threading
@@ -18,9 +19,9 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
MooncakeAgentMetadata, MooncakeLayerwiseConnector,
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv,
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash,
zmq_ctx)
MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, SendTask,
ensure_zmq_recv, ensure_zmq_send, group_concurrent_contiguous,
string_to_int64_hash, zmq_ctx)
GET_META_MSG = b"get_meta_msg"
DONE_SENDING_MSG = b"done_sending_msg"
@@ -32,14 +33,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.engine = MagicMock()
self.engine.register_memory.return_value = 0
self.engine.batch_transfer_sync_write.return_value = 1
self._patcher_cs = patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.torch_npu.npu.current_stream'
)
self.mock_current_stream = self._patcher_cs.start()
self.addCleanup(self._patcher_cs.stop)
fake_stream = MagicMock(name="FakeStream")
fake_stream.synchronize = MagicMock()
self.mock_current_stream.return_value = fake_stream
self.first_kv_cache = torch.zeros((2, 2, 2, 8),
dtype=torch.float32,
@@ -47,6 +42,14 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.ready_event = threading.Event()
self.fake_k_buffer = MagicMock()
self.fake_v_buffer = MagicMock()
fake_resharding_stream = MagicMock()
cap = self.first_kv_cache.numel() // self.first_kv_cache.shape[-1]
dim = self.first_kv_cache.shape[-1]
self.key = torch.zeros((cap, dim), dtype=torch.float32)
self.value = torch.zeros((cap, dim), dtype=torch.float32)
self.thread = KVCacheSendingLayerThread(
engine=self.engine,
total_layers=3,
@@ -60,6 +63,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
block_len=[1024, 2048],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
callback_func=MagicMock())
self.req_meta_base = ReqMeta(
@@ -74,6 +80,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
metaserver="http://dummy",
chunk_finish=False)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.npu_stream_switch",
side_effect=lambda *_args, **_kwargs: contextlib.nullcontext())
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
autospec=True,
@@ -87,7 +96,10 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
)
def test_transfer_pd_gt1_uses_buffers_and_calls_engine(
self, mock_group, _mock_sync, _mock_align, _mock_dataptr):
self, mock_group, _mock_sync, _mock_align, _mock_dataptr,
mock_stream_switch):
fake_resharding_stream = MagicMock()
thread = KVCacheSendingLayerThread(
engine=self.engine,
@@ -101,26 +113,28 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
block_len=[64],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
callback_func=MagicMock())
req_meta = self.req_meta_base
req_meta.remote_kv_caches_base_addr = [4000, 8000]
mock_group.return_value = ([[10, 11], [20, 21]], [])
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)
cap = self.first_kv_cache.numel() // self.first_kv_cache.shape[-1]
dim = self.first_kv_cache.shape[-1]
send_task = SendTask(
send_request={"req1": req_meta},
wait_event=MagicMock(),
k_cache=key,
v_cache=value,
layer_idx=0,
rearrange_block_ids=[5, 8],
)
key = torch.zeros((cap, dim), dtype=torch.float32)
value = torch.zeros((cap, dim), dtype=torch.float32)
thread._transfer_kv_cache( # type: ignore
req_id="req1",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=MagicMock())
thread._transfer_kv_cache(send_task)
self.engine.batch_transfer_sync_write.assert_called_once()
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
@@ -145,60 +159,45 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
def test_transfer_skips_when_no_local_blocks(self):
req_meta = self.req_meta_base
req_meta.local_block_ids = []
self.thread.pd_head_ratio = 1
self.thread.block_len = [64, 128]
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)
reshape_cache_event = MagicMock()
with patch.object(self.engine,
'batch_transfer_sync_write') as mock_batch_transfer:
mock_batch_transfer.return_value = 1
def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key,
value,
reshape_cache_event): # type: ignore
if not req_meta.local_block_ids:
return
self._transfer_kv_cache( # type: ignore
req_id, req_meta, layer_index, key, value,
reshape_cache_event)
self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore
self.thread._transfer_kv_cache( # type: ignore
req_id="req2",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=reshape_cache_event)
mock_batch_transfer.assert_not_called()
self.assertEqual(mock_batch_transfer.call_count, 0)
send_task = SendTask(
send_request={"req2": req_meta},
wait_event=MagicMock(),
k_cache=torch.zeros((1, 8)),
v_cache=torch.zeros((1, 8)),
layer_idx=0,
rearrange_block_ids=[],
)
self.thread._transfer_kv_cache(send_task)
self.engine.batch_transfer_sync_write.assert_not_called()
def test_transfer_skips_when_tp_not_sender(self):
thread = KVCacheSendingLayerThread(engine=self.engine,
total_layers=2,
ready_event=self.ready_event,
tp_rank=1,
pd_head_ratio=1,
num_head_replica=2,
kv_cache_base_addr=[1000, 2000],
use_mla=False,
block_len=[1024],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
callback_func=MagicMock())
thread = KVCacheSendingLayerThread(
engine=self.engine,
total_layers=2,
ready_event=self.ready_event,
tp_rank=1,
pd_head_ratio=1,
num_head_replica=2,
kv_cache_base_addr=[1000, 2000, 3000, 4000],
use_mla=False,
block_len=[1024],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=MagicMock(),
v_buffer=MagicMock(),
resharding_stream=MagicMock(),
callback_func=MagicMock())
req_meta = self.req_meta_base
thread._transfer_kv_cache( # type: ignore
"req3",
req_meta,
0,
torch.zeros((1, 8)),
torch.zeros((1, 8)),
reshape_cache_event=MagicMock())
send_task = SendTask(
send_request={"req3": req_meta},
wait_event=MagicMock(),
k_cache=self.key,
v_cache=self.value,
layer_idx=1,
rearrange_block_ids=[],
)
thread._transfer_kv_cache(send_task)
self.engine.batch_transfer_sync_write.assert_not_called()
@patch(
@@ -208,30 +207,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
)
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
req_meta = self.req_meta_base
req_meta.chunk_finish = True
req_meta.local_block_ids = [5, 6]
req_meta.remote_block_ids = [10, 11]
req_meta.remote_kv_caches_base_addr = [
7000, 8000, 9000, 10000, 11000, 12000
]
req_meta.chunk_finish = True
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)
send_task = MagicMock()
send_task.layer_index = self.thread.total_layers - 1
send_task.send_request = {"req5": req_meta}
send_task = SendTask(
send_request={"req5": req_meta},
wait_event=MagicMock(),
k_cache=key,
v_cache=value,
layer_idx=2,
rearrange_block_ids=[],
)
self.thread._transfer_kv_cache(send_task)
with patch.object(self.thread, 'callback_func') as mock_callback_func:
self.thread._transfer_kv_cache( # type: ignore
req_id="req5",
req_meta=req_meta,
layer_index=send_task.layer_index,
key=key,
value=value,
reshape_cache_event=MagicMock())
print(f"Callback called: {mock_callback_func.call_count} times")
mock_callback_func.assert_called_once()
self.thread.callback_func.assert_called_once()
class TestKVCacheRecvingLayerThread(unittest.TestCase):
@@ -506,10 +505,10 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase):
self.assertTrue(async_flag)
def test_build_connector_meta(self):
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
request = MockRequest("req1")
self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
@@ -554,9 +553,9 @@ class _MockSchedulerOutput:
new_block_ids=cached_new_block_ids or [],
num_computed_tokens=cached_num_computed or [],
)
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
self.scheduled_new_reqs = new_reqs or []
self.num_scheduled_tokens = num_sched or {}
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
@@ -593,14 +592,12 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))
def test_update_state_after_alloc_decode_records_send_layerwise(self):
req = MockRequest(
"req_u2",
prompt_token_ids=list(range(10)),
kv_transfer_params={
"do_remote_decode": True,
"remote_block_ids": [] # 修改为空列表 []
})
req = MockRequest("req_u2",
prompt_token_ids=list(range(10)),
kv_transfer_params={
"do_remote_decode": True,
"remote_block_ids": []
})
blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
self.scheduler.update_state_after_alloc(req,
blocks,
@@ -610,7 +607,24 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
self.assertEqual(info.local_block_ids, [7, 8, 9])
self.assertIs(info.request, req)
self.assertEqual(info.remote_block_ids, [])
self.assertIsInstance(info.remote_block_ids, list)
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
req = MockRequest("req_b1",
kv_transfer_params={
"remote_block_ids": [1, 2],
"remote_engine_id": "E",
"remote_host": "H",
"remote_port": 5555,
"remote_te_rpc_port": 6000,
"remote_kv_caches_base_addr": [10, 11],
})
self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101])
meta = self.scheduler.build_connector_meta(_MockSchedulerOutput())
self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata)
self.assertIn("req_b1", meta.requests)
self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_build_connector_meta_accumulates_cached_blocks(self):
req_meta = MagicMock(spec=ReqMeta)
@@ -637,8 +651,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
meta = self.scheduler.build_connector_meta(out)
self.assertEqual(len(meta.requests), 0)
req_meta.extend_local_block_ids.assert_called_once_with([3, 4])
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
)
@@ -684,6 +696,12 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
send_req_info.extend_local_block_ids.assert_called_once_with([50])
self.assertIn("req_b3", meta.requests)
def test_request_finished_returns_false_none(self):
ok, params = self.scheduler.request_finished(MockRequest("req_fin"),
[1, 2])
self.assertFalse(ok)
self.assertIsNone(params)
class TestHelperFunctions(unittest.TestCase):


@@ -11,7 +11,7 @@ import time
from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import httpx
@@ -19,7 +19,6 @@ import msgspec
import numpy as np
import numpy.typing as npt
import torch
import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm.config import VllmConfig
@@ -37,6 +36,7 @@ from vllm_ascend.distributed.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.utils import (align_memory,
get_transfer_timeout_value,
kv_alltoall_and_rearrange)
from vllm_ascend.utils import npu_stream_switch
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -56,7 +56,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
@dataclass
class ReqMeta:
local_block_ids: list[int]
token_ids: list[int]
token_ids: Optional[list[int]]
# Not None if layer-wise is disabled
remote_block_ids: list[int]
remote_engine_id: Optional[str]
@@ -68,6 +68,26 @@ class ReqMeta:
chunk_finish: Optional[bool]
@dataclass
class SendTask:
send_request: dict[str, ReqMeta] = field(default_factory=dict)
# used when pd_head_ratio == 1
wait_event: Optional[torch.npu.Event] = None
# used when pd_head_ratio > 1
k_cache: Optional[torch.Tensor] = None
v_cache: Optional[torch.Tensor] = None
layer_idx: int = 0
rearrange_block_ids: Optional[list[int]] = None
@dataclass
class TransferMeta:
src: list[int]
dst: list[int]
length: list[int]
req_ids: list[str]
@dataclass
class SendReqInfo:
local_block_ids: list[int]
@@ -116,19 +136,24 @@ class SizedDict(OrderedDict):
class KVCacheSendingLayerThread(threading.Thread):
def __init__(self,
engine: TransferEngine,
total_layers: int,
ready_event: threading.Event,
tp_rank: int,
pd_head_ratio: int,
num_head_replica: int,
kv_cache_base_addr: list[int],
use_mla: bool,
block_len: list[int],
decode_tp_size: int,
first_kv_cache: torch.Tensor,
callback_func: Callable[..., None] = lambda x: None):
def __init__(
self,
engine: TransferEngine,
total_layers: int,
ready_event: threading.Event,
tp_rank: int,
pd_head_ratio: int,
num_head_replica: int,
kv_cache_base_addr: list[int],
use_mla: bool,
block_len: list[int],
decode_tp_size: int,
first_kv_cache: torch.Tensor,
k_buffer: torch.Tensor,
v_buffer: torch.Tensor,
resharding_stream: torch.npu.Stream,
callback_func: Callable[..., None] = lambda x: None,
):
super().__init__(daemon=True, name="KVCacheSendingLayerThread")
self.engine = engine
self.tp_rank = tp_rank
@@ -139,39 +164,12 @@ class KVCacheSendingLayerThread(threading.Thread):
self.use_mla = use_mla
self.block_len = block_len
self._decode_tp_size = decode_tp_size
self.model_stream = torch_npu.npu.current_stream()
self.resharding_stream = resharding_stream
self.current_layer = -1
if self.pd_head_ratio > 1:
# register kv buffer for unequal TP
alignment = 2 * 1024 * 1024
self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.k_buffer = align_memory(
self.k_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.v_buffer = align_memory(
self.v_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
for tensor in (self.k_buffer, self.v_buffer):
assert tensor.data_ptr(
) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
ret_value = self.engine.register_memory(
tensor.data_ptr(), tensor.numel())
logger.info(
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
torch.Tensor, torch.npu.Event]]()
self.send_queue = queue.Queue[SendTask]()
self.k_buffer = k_buffer
self.v_buffer = v_buffer
self.ready_event = ready_event
self.callback_func = callback_func
@@ -181,43 +179,36 @@ class KVCacheSendingLayerThread(threading.Thread):
torch.npu.set_device(device)
self.ready_event.set()
while True:
req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get(
)
self._handle_request(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
send_task = self.send_queue.get()
self._handle_request(send_task)
def _handle_request(self, req_id, req_meta, layer_index, key, value,
reshape_cache_event):
def _handle_request(self, send_task: SendTask):
try:
logger.debug(
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
logger.debug(
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
self._transfer_kv_cache(send_task)
except Exception as e:
logger.error("Failed to transfer KV cache for request "
f"{req_id}: {e}")
logger.error(
f"Failed to transfer KV cache for layer idx {send_task.layer_idx}, {e}"
)
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value,
reshape_cache_event):
def get_transfer_meta(self, send_task: SendTask, req_id: str,
req_meta: ReqMeta):
src_list: list[str] = []
dst_list: list[str] = []
length_list: list[int] = []
# no need to send kv cache
if self.tp_rank % self.num_head_replica != 0:
logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
)
return
return (src_list, dst_list, length_list)
if self.use_mla and self.tp_rank >= self._decode_tp_size:
logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
)
return
return (src_list, dst_list, length_list)
remote_host = req_meta.remote_host
layer_idx = send_task.layer_idx
remote_block_ids = req_meta.remote_block_ids
remote_te_port = req_meta.remote_te_rpc_port
remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr
local_kv_base_addr = self.kv_caches_base_addr
local_block_ids = req_meta.local_block_ids
@@ -225,17 +216,15 @@ class KVCacheSendingLayerThread(threading.Thread):
if self.pd_head_ratio == 1:
layer_local_kv_base_addr = [
local_kv_base_addr[i]
for i in [2 * layer_index, 2 * layer_index + 1]
for i in [2 * layer_idx, 2 * layer_idx + 1]
]
layer_remote_kv_base_addr = [
remote_kv_base_addrs[i]
for i in [2 * layer_index, 2 * layer_index + 1]
remote_kv_base_addrs[i] # type:ignore
for i in [2 * layer_idx, 2 * layer_idx + 1]
]
grouped_remote_block_ids, grouped_local_block_ids = \
group_concurrent_contiguous(remote_block_ids, local_block_ids)
session_id = f"{remote_host}:{remote_te_port}"
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
block_len = self.block_len[
@@ -250,74 +239,101 @@ class KVCacheSendingLayerThread(threading.Thread):
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
if self.current_layer != layer_index:
self.current_layer = layer_index
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
reshape_cache_event.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
logger.error("Mooncake transfer failed for request %s", req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
else:
key = key.view(-1, key.shape[-1])
value = value.view(-1, key.shape[-1])
self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] ->
self.v_buffer[:value.shape[0]].copy_(value)
rearrange_block_ids = send_task.rearrange_block_ids
rearrange_block_dict = {
value: index
for index, value in enumerate(
rearrange_block_ids) # type:ignore
}
layer_local_kv_base_addr = [
self.k_buffer.data_ptr(),
self.v_buffer.data_ptr()
]
layer_remote_kv_base_addr = [
remote_kv_base_addrs[i]
for i in [2 * layer_index, 2 * layer_index + 1]
remote_kv_base_addrs[i] # type:ignore
for i in [2 * layer_idx, 2 * layer_idx + 1]
]
grouped_remote_block_ids, _ = group_concurrent_contiguous(
remote_block_ids)
session_id = f"{remote_host}:{remote_te_port}"
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
src_layer_addr = src_layer_base_addr
for group_remote_block_id in grouped_remote_block_ids:
block_len = self.block_len[0]
remote_block_len = self.block_len[0] * self.pd_head_ratio
src_list.append(src_layer_addr)
block_len = self.block_len[0]
remote_block_len = self.block_len[0] * self.pd_head_ratio
for remote_block_id, local_block_id in zip(
remote_block_ids, local_block_ids):
src = src_layer_base_addr + rearrange_block_dict[
local_block_id] * block_len
dst = dst_layer_base_addr + remote_block_id * remote_block_len + block_len * (
(self.tp_rank // self.num_head_replica) %
self.pd_head_ratio)
src_list.append(src)
dst_list.append(dst)
length_list.append(block_len)
return (src_list, dst_list, length_list)
if src_layer_addr + len(
group_remote_block_id
) * block_len > src_layer_base_addr + key.numel(
) * key.element_size():
length = src_layer_base_addr + key.numel(
) * key.element_size() - src_layer_addr
else:
length = len(group_remote_block_id) * block_len
length_list.append(length)
def _transfer_kv_cache(self, send_task: SendTask):
if self.pd_head_ratio > 1:
with npu_stream_switch(self.resharding_stream):
key = send_task.k_cache
value = send_task.v_cache
key = key.view(-1, key.shape[-1]) # type:ignore
value = value.view(-1, key.shape[-1]) # type:ignore
self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] ->
self.v_buffer[:value.shape[0]].copy_(value)
dst_list.append(dst_layer_base_addr +
group_remote_block_id[0] *
remote_block_len + length *
((self.tp_rank // self.num_head_replica) %
self.pd_head_ratio))
src_layer_addr += length
self.model_stream.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
logger.error("Mooncake transfer failed for request %s", req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
# Merge transmission tasks of the same session
session_meta: dict[str, TransferMeta] = {}
for req_id, req_meta in send_task.send_request.items():
session_id = f"{req_meta.remote_host}:{req_meta.remote_te_rpc_port}"
if session_id not in session_meta.keys():
session_meta[session_id] = TransferMeta(src=[],
dst=[],
length=[],
req_ids=[])
if layer_index == (self.total_layers - 1) and req_meta.chunk_finish:
self.callback_func(req_id, req_meta)
(src_list, dst_list,
length_list) = self.get_transfer_meta(send_task, req_id, req_meta)
session_meta[session_id].src.extend(src_list)
session_meta[session_id].dst.extend(dst_list)
session_meta[session_id].length.extend(length_list)
session_meta[session_id].req_ids.append(req_id)
if self.pd_head_ratio == 1:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
send_task.wait_event.synchronize() # type:ignore
elif self.pd_head_ratio > 1:
self.resharding_stream.synchronize()
for session_id, transfer_meta in session_meta.items():
if len(transfer_meta.src) > 0:
ret = self.engine.batch_transfer_sync_write(
session_id, transfer_meta.src, transfer_meta.dst,
transfer_meta.length)
if ret < 0:
logger.error(
f"Mooncake transfer failed for send requests {transfer_meta.req_ids} kv cache to {session_id}"
)
if send_task.layer_idx == (self.total_layers - 1):
for req_id in transfer_meta.req_ids:
req_meta = send_task.send_request[req_id]
if req_meta.chunk_finish:
self.callback_func(
req_id, req_meta
) # TODO Send a signal indicating transmission failure
else:
if send_task.layer_idx == (self.total_layers - 1):
for req_id in transfer_meta.req_ids:
req_meta = send_task.send_request[req_id]
if req_meta.chunk_finish:
self.callback_func(req_id, req_meta)
class KVCacheRecvingLayerThread(threading.Thread):
@@ -836,8 +852,10 @@ class MooncakeLayerwiseConnectorWorker:
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
self.pd_head_ratio = get_ascend_config().pd_head_ratio
self.num_head_replica = get_ascend_config().num_head_replica
self.resharding_stream = None
if self.pd_head_ratio > 1:
self.resharding_stream = torch.npu.Stream()
self.first_kv_cache = None
self.remote_poller = zmq.Poller() # type: ignore
self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
self.encoder = msgspec.msgpack.Encoder()
@@ -852,6 +870,8 @@ class MooncakeLayerwiseConnectorWorker:
deque)
self.remote_poller = zmq.Poller() # type: ignore
self.timeout = 1.0 # seconds
self.k_buffer: Optional[torch.Tensor] = None
self.v_buffer: Optional[torch.Tensor] = None
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
@@ -874,12 +894,40 @@ class MooncakeLayerwiseConnectorWorker:
assert "dp_size" in decode_parallel_config.keys()
self._decode_dp_size = decode_parallel_config["dp_size"]
def create_kv_buffer(self, first_kv_cache):
if self.pd_head_ratio > 1:
# register kv buffer for unequal TP
alignment = 2 * 1024 * 1024
self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.k_buffer = align_memory(
self.k_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.v_buffer = align_memory(
self.v_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
for tensor in (self.k_buffer, self.v_buffer):
assert tensor.data_ptr(
) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
ret_value = self.engine.register_memory(
tensor.data_ptr(), tensor.numel())
logger.info(
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data."""
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
self.first_kv_cache = first_kv_cache
self.create_kv_buffer(first_kv_cache)
# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size(
@@ -954,6 +1002,9 @@ class MooncakeLayerwiseConnectorWorker:
block_len=self.block_len,
decode_tp_size=self._decode_tp_size,
first_kv_cache=first_kv_cache,
k_buffer=self.k_buffer,
v_buffer=self.v_buffer,
resharding_stream=self.resharding_stream,
callback_func=self.send_done_send_signal)
self.kv_send_layer_thread.start()
ready_event.wait()
@@ -1002,60 +1053,49 @@ class MooncakeLayerwiseConnectorWorker:
reshape_cache_event = attn_metadata.reshape_cache_event
if self.pd_head_ratio != 1:
assert self.resharding_stream is not None
with npu_stream_switch(self.resharding_stream):
reshape_cache_event.wait()
rearrange_block_ids = sorted({
block_id
for request in connector_metadata.requests.values()
for block_id in request.local_block_ids
})
def sort_kv_cache(input_kv: list[list[int]]):
return torch.cat([
torch.chunk(tensor, self.pd_head_ratio, dim=0)[x]
for x in range(self.pd_head_ratio)
for tensor in input_kv
])
total_block_ids = [
request.local_block_ids
for request in connector_metadata.requests.values()
]
keys = [
kv_layer[0][block_ids].reshape(
-1, *kv_layer[0].shape[2:]).clone()
for block_ids in total_block_ids
]
values = [
kv_layer[1][block_ids].reshape(
-1, *kv_layer[1].shape[2:]).clone()
for block_ids in total_block_ids
]
key_block_size = keys[0].size(0) // len(total_block_ids[0])
value_block_size = values[0].size(0) // len(total_block_ids[0])
keys = sort_kv_cache(keys) # [req1_key, req2_key]
values = sort_kv_cache(values)
(keys,
values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys,
values)
key_start_id = 0
value_start_id = 0
keys = kv_layer[0][rearrange_block_ids].clone()
values = kv_layer[1][rearrange_block_ids].clone()
# sort kv caches for each block
keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
*keys.shape[2:]).transpose(
0, 1).reshape_as(keys)
values = values.view(values.size(0), self.pd_head_ratio,
-1, *values.shape[2:]).transpose(
0, 1).reshape_as(values)
# reshard kv cache
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
values = values.reshape(-1, *kv_layer[1].shape[2:])
(keys, values) = kv_alltoall_and_rearrange(
self.pd_head_ratio, keys, values)
else:
key = None
value = None
keys = None
values = None
rearrange_block_ids = None
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
send_task = SendTask(wait_event=reshape_cache_event,
k_cache=keys,
v_cache=values,
layer_idx=self.current_layer,
rearrange_block_ids=rearrange_block_ids)
for req_id, req_meta in connector_metadata.requests.items():
if self.pd_head_ratio != 1:
key_block_num = len(
req_meta.local_block_ids) * key_block_size
value_block_num = len(
req_meta.local_block_ids) * value_block_size
key = keys[key_start_id:key_start_id + key_block_num]
value = values[value_start_id:value_start_id +
value_block_num]
key_start_id += key_block_num
value_start_id += value_block_num
req_meta_update = self.update_decoder_info(req_id, req_meta)
logger.debug(
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
)
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
self.kv_send_layer_thread.send_queue.put(
(req_id, req_meta_update, self.current_layer, key, value,
reshape_cache_event))
send_task.send_request[req_id] = req_meta_update
self.kv_send_layer_thread.send_queue.put(send_task)
self.current_layer += 1
def _get_remote_socket(
@@ -1106,6 +1146,14 @@ class MooncakeLayerwiseConnectorWorker:
logger.info(
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
)
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
ret = self.engine.batch_transfer_sync_write(
session_id, [self.kv_caches_base_addr[0]],
[agent_meta.kv_caches_base_addr[0]], 128)
if ret < 0:
logger.error(
f"Mooncake transfer failed to create link to device {session_id}"
)
req_meta_update.remote_te_rpc_port = self.remote_te_port[
req_meta_update.remote_engine_id][req_meta_update.remote_port]
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[