From 330e25ab1d83c3429edab25ec6fc25b9c59e2221 Mon Sep 17 00:00:00 2001 From: liziyu <56102866+liziyu179@users.noreply.github.com> Date: Tue, 6 Jan 2026 20:25:36 +0800 Subject: [PATCH] [P/D] Performance enhancement of Layerwise connector in TP asymmetric scenarios (#5540) ### What this PR does / why we need it? [P/D] Performance enhancement of the Layerwise connector in TP asymmetric scenarios: 1. Session fusion: at each layer, merge transmission tasks that share the same destination into a single batched task before dispatch. 2. Alltoall aggregation: in TP asymmetric scenarios, batch the alltoall operations for all requests into a single block-granularity call per layer. [RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache Layerwise Push Support https://github.com/vllm-project/vllm-ascend/issues/4842 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/45c1ca1ca1ee8fa06df263c8715e8a412ff408d4 --------- Signed-off-by: liziyu Signed-off-by: nwpu-zxr Signed-off-by: wangxiaoteng Co-authored-by: nwpu-zxr Co-authored-by: wangxiaoteng --- .../test_mooncake_layerwise_connector.py | 216 +++++----- .../mooncake_layerwise_connector.py | 400 ++++++++++-------- 2 files changed, 341 insertions(+), 275 deletions(-) diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index 6eb38454..0b55c265 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -1,3 +1,4 @@ +import contextlib import os import sys import threading @@ -18,9 +19,9 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402 KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole, MooncakeAgentMetadata, MooncakeLayerwiseConnector, MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler, - MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv, - ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash, - zmq_ctx) + MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, SendTask, + ensure_zmq_recv, ensure_zmq_send, group_concurrent_contiguous, + string_to_int64_hash, zmq_ctx) GET_META_MSG = b"get_meta_msg" DONE_SENDING_MSG = b"done_sending_msg" @@ -32,14 +33,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): self.engine = MagicMock() self.engine.register_memory.return_value = 0 self.engine.batch_transfer_sync_write.return_value = 1 - self._patcher_cs = patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.torch_npu.npu.current_stream' ) - self.mock_current_stream = self._patcher_cs.start() - self.addCleanup(self._patcher_cs.stop) fake_stream = MagicMock(name="FakeStream") fake_stream.synchronize = MagicMock() - self.mock_current_stream.return_value = fake_stream self.first_kv_cache = torch.zeros((2, 2, 2, 8), dtype=torch.float32, @@ -47,6 +42,14 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): self.ready_event = threading.Event() + self.fake_k_buffer = MagicMock() + self.fake_v_buffer = MagicMock() + fake_resharding_stream = MagicMock() + cap = self.first_kv_cache.numel() // self.first_kv_cache.shape[-1] + dim = self.first_kv_cache.shape[-1] + + self.key = torch.zeros((cap, dim), dtype=torch.float32) + self.value = torch.zeros((cap, dim), dtype=torch.float32) self.thread = KVCacheSendingLayerThread( engine=self.engine,
total_layers=3, @@ -60,6 +63,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): block_len=[1024, 2048], decode_tp_size=1, first_kv_cache=self.first_kv_cache, + k_buffer=self.fake_k_buffer, + v_buffer=self.fake_v_buffer, + resharding_stream=fake_resharding_stream, callback_func=MagicMock()) self.req_meta_base = ReqMeta( @@ -74,6 +80,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): metaserver="http://dummy", chunk_finish=False) + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.npu_stream_switch", + side_effect=lambda *_args, **_kwargs: contextlib.nullcontext()) @patch( "vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr", autospec=True, @@ -87,7 +96,10 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous" ) def test_transfer_pd_gt1_uses_buffers_and_calls_engine( - self, mock_group, _mock_sync, _mock_align, _mock_dataptr): + self, mock_group, _mock_sync, _mock_align, _mock_dataptr, + mock_stream_switch): + + fake_resharding_stream = MagicMock() thread = KVCacheSendingLayerThread( engine=self.engine, @@ -101,26 +113,28 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): block_len=[64], decode_tp_size=1, first_kv_cache=self.first_kv_cache, + k_buffer=self.fake_k_buffer, + v_buffer=self.fake_v_buffer, + resharding_stream=fake_resharding_stream, callback_func=MagicMock()) req_meta = self.req_meta_base req_meta.remote_kv_caches_base_addr = [4000, 8000] mock_group.return_value = ([[10, 11], [20, 21]], []) + key = torch.zeros((1, 8), dtype=torch.float32) + value = torch.zeros((1, 8), dtype=torch.float32) - cap = self.first_kv_cache.numel() // self.first_kv_cache.shape[-1] - dim = self.first_kv_cache.shape[-1] + send_task = SendTask( + send_request={"req1": req_meta}, + wait_event=MagicMock(), + k_cache=key, + v_cache=value, + layer_idx=0, + rearrange_block_ids=[5, 8], + ) - key = torch.zeros((cap, dim), dtype=torch.float32) - value = torch.zeros((cap, dim), dtype=torch.float32) - - thread._transfer_kv_cache( # type: ignore - req_id="req1", - req_meta=req_meta, - layer_index=0, - key=key, - value=value, - reshape_cache_event=MagicMock()) + thread._transfer_kv_cache(send_task) self.engine.batch_transfer_sync_write.assert_called_once() session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ @@ -145,60 +159,45 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): def test_transfer_skips_when_no_local_blocks(self): req_meta = self.req_meta_base req_meta.local_block_ids = [] - self.thread.pd_head_ratio = 1 - self.thread.block_len = [64, 128] - - key = torch.zeros((1, 8), dtype=torch.float32) - value = torch.zeros((1, 8), dtype=torch.float32) - - reshape_cache_event = MagicMock() - with patch.object(self.engine, - 'batch_transfer_sync_write') as mock_batch_transfer: - mock_batch_transfer.return_value = 1 - - def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key, - value, - reshape_cache_event): # type: ignore - if not req_meta.local_block_ids: - return - self._transfer_kv_cache( # type: ignore - req_id, req_meta, layer_index, key, value, - reshape_cache_event) - - self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore - self.thread._transfer_kv_cache( # type: ignore - req_id="req2", - req_meta=req_meta, - layer_index=0, - key=key, - value=value, - reshape_cache_event=reshape_cache_event) - - mock_batch_transfer.assert_not_called() - self.assertEqual(mock_batch_transfer.call_count, 0) 
+ send_task = SendTask( + send_request={"req2": req_meta}, + wait_event=MagicMock(), + k_cache=torch.zeros((1, 8)), + v_cache=torch.zeros((1, 8)), + layer_idx=0, + rearrange_block_ids=[], + ) + self.thread._transfer_kv_cache(send_task) + self.engine.batch_transfer_sync_write.assert_not_called() def test_transfer_skips_when_tp_not_sender(self): - thread = KVCacheSendingLayerThread(engine=self.engine, - total_layers=2, - ready_event=self.ready_event, - tp_rank=1, - pd_head_ratio=1, - num_head_replica=2, - kv_cache_base_addr=[1000, 2000], - use_mla=False, - block_len=[1024], - decode_tp_size=1, - first_kv_cache=self.first_kv_cache, - callback_func=MagicMock()) + thread = KVCacheSendingLayerThread( + engine=self.engine, + total_layers=2, + ready_event=self.ready_event, + tp_rank=1, + pd_head_ratio=1, + num_head_replica=2, + kv_cache_base_addr=[1000, 2000, 3000, 4000], + use_mla=False, + block_len=[1024], + decode_tp_size=1, + first_kv_cache=self.first_kv_cache, + k_buffer=MagicMock(), + v_buffer=MagicMock(), + resharding_stream=MagicMock(), + callback_func=MagicMock()) req_meta = self.req_meta_base - thread._transfer_kv_cache( # type: ignore - "req3", - req_meta, - 0, - torch.zeros((1, 8)), - torch.zeros((1, 8)), - reshape_cache_event=MagicMock()) + send_task = SendTask( + send_request={"req3": req_meta}, + wait_event=MagicMock(), + k_cache=self.key, + v_cache=self.value, + layer_idx=1, + rearrange_block_ids=[], + ) + thread._transfer_kv_cache(send_task) self.engine.batch_transfer_sync_write.assert_not_called() @patch( @@ -208,30 +207,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): "vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize" ) def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): + req_meta = self.req_meta_base + req_meta.chunk_finish = True req_meta.local_block_ids = [5, 6] req_meta.remote_block_ids = [10, 11] + req_meta.remote_kv_caches_base_addr = [ 7000, 8000, 9000, 10000, 11000, 12000 ] - req_meta.chunk_finish = True + key = torch.zeros((1, 8), dtype=torch.float32) value = torch.zeros((1, 8), dtype=torch.float32) - send_task = MagicMock() - send_task.layer_index = self.thread.total_layers - 1 - send_task.send_request = {"req5": req_meta} + send_task = SendTask( + send_request={"req5": req_meta}, + wait_event=MagicMock(), + k_cache=key, + v_cache=value, + layer_idx=2, + rearrange_block_ids=[], + ) + self.thread._transfer_kv_cache(send_task) - with patch.object(self.thread, 'callback_func') as mock_callback_func: - self.thread._transfer_kv_cache( # type: ignore - req_id="req5", - req_meta=req_meta, - layer_index=send_task.layer_index, - key=key, - value=value, - reshape_cache_event=MagicMock()) - print(f"Callback called: {mock_callback_func.call_count} times") - mock_callback_func.assert_called_once() + self.thread.callback_func.assert_called_once() class TestKVCacheRecvingLayerThread(unittest.TestCase): @@ -506,10 +505,10 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase): self.assertTrue(async_flag) def test_build_connector_meta(self): + self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True request = MockRequest("req1") self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6]) - self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True request.kv_transfer_params = { "remote_block_ids": [1, 2, 3], "remote_engine_id": "remote", @@ -554,9 +553,9 @@ class _MockSchedulerOutput: new_block_ids=cached_new_block_ids or [], num_computed_tokens=cached_num_computed or [], ) 
+ self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {} self.scheduled_new_reqs = new_reqs or [] self.num_scheduled_tokens = num_sched or {} - self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {} class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): @@ -593,14 +592,12 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True)) def test_update_state_after_alloc_decode_records_send_layerwise(self): - req = MockRequest( - "req_u2", - prompt_token_ids=list(range(10)), - kv_transfer_params={ - "do_remote_decode": True, - "remote_block_ids": [] # 修改为空列表 [] - }) - + req = MockRequest("req_u2", + prompt_token_ids=list(range(10)), + kv_transfer_params={ + "do_remote_decode": True, + "remote_block_ids": [] + }) blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) self.scheduler.update_state_after_alloc(req, blocks, @@ -610,7 +607,24 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): self.assertEqual(info.local_block_ids, [7, 8, 9]) self.assertIs(info.request, req) self.assertEqual(info.remote_block_ids, []) - self.assertIsInstance(info.remote_block_ids, list) + + def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): + self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True + req = MockRequest("req_b1", + kv_transfer_params={ + "remote_block_ids": [1, 2], + "remote_engine_id": "E", + "remote_host": "H", + "remote_port": 5555, + "remote_te_rpc_port": 6000, + "remote_kv_caches_base_addr": [10, 11], + }) + self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101]) + meta = self.scheduler.build_connector_meta(_MockSchedulerOutput()) + self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata) + self.assertIn("req_b1", meta.requests) + self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101]) + self.assertEqual(len(self.scheduler._reqs_need_recv), 0) def test_build_connector_meta_accumulates_cached_blocks(self): req_meta = MagicMock(spec=ReqMeta) @@ -637,8 +651,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): meta = self.scheduler.build_connector_meta(out) self.assertEqual(len(meta.requests), 0) - req_meta.extend_local_block_ids.assert_called_once_with([3, 4]) - @patch( "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous" ) @@ -684,6 +696,12 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): send_req_info.extend_local_block_ids.assert_called_once_with([50]) self.assertIn("req_b3", meta.requests) + def test_request_finished_returns_false_none(self): + ok, params = self.scheduler.request_finished(MockRequest("req_fin"), + [1, 2]) + self.assertFalse(ok) + self.assertIsNone(params) + class TestHelperFunctions(unittest.TestCase): diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index 9d9d9301..edbba3f7 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -11,7 +11,7 @@ import time from collections import OrderedDict, defaultdict, deque from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple import httpx @@ -19,7 +19,6 @@ import msgspec import numpy as np import numpy.typing as 
npt import torch -import torch_npu import zmq from mooncake.engine import TransferEngine # type: ignore from vllm.config import VllmConfig @@ -37,6 +36,7 @@ from vllm_ascend.distributed.mooncake_transfer_engine import global_te from vllm_ascend.distributed.utils import (align_memory, get_transfer_timeout_value, kv_alltoall_and_rearrange) +from vllm_ascend.utils import npu_stream_switch if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -56,7 +56,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): @dataclass class ReqMeta: local_block_ids: list[int] - token_ids: list[int] + token_ids: Optional[list[int]] # Not None if layer-wise is disabled remote_block_ids: list[int] remote_engine_id: Optional[str] @@ -68,6 +68,26 @@ class ReqMeta: chunk_finish: Optional[bool] +@dataclass +class SendTask: + send_request: dict[str, ReqMeta] = field(default_factory=dict) + # used when pd_head_ratio == 1 + wait_event: Optional[torch.npu.Event] = None + # used when pd_head_ratio > 1 + k_cache: Optional[torch.Tensor] = None + v_cache: Optional[torch.Tensor] = None + layer_idx: int = 0 + rearrange_block_ids: Optional[list[int]] = None + + +@dataclass +class TransferMeta: + src: list[int] + dst: list[int] + length: list[int] + req_ids: list[str] + + @dataclass class SendReqInfo: local_block_ids: list[int] @@ -116,19 +136,24 @@ class SizedDict(OrderedDict): class KVCacheSendingLayerThread(threading.Thread): - def __init__(self, - engine: TransferEngine, - total_layers: int, - ready_event: threading.Event, - tp_rank: int, - pd_head_ratio: int, - num_head_replica: int, - kv_cache_base_addr: list[int], - use_mla: bool, - block_len: list[int], - decode_tp_size: int, - first_kv_cache: torch.Tensor, - callback_func: Callable[..., None] = lambda x: None): + def __init__( + self, + engine: TransferEngine, + total_layers: int, + ready_event: threading.Event, + tp_rank: int, + pd_head_ratio: int, + num_head_replica: int, + kv_cache_base_addr: list[int], + use_mla: bool, + block_len: list[int], + decode_tp_size: int, + first_kv_cache: torch.Tensor, + k_buffer: torch.Tensor, + v_buffer: torch.Tensor, + resharding_stream: torch.npu.Stream, + callback_func: Callable[..., None] = lambda x: None, + ): super().__init__(daemon=True, name="KVCacheSendingLayerThread") self.engine = engine self.tp_rank = tp_rank @@ -139,39 +164,12 @@ class KVCacheSendingLayerThread(threading.Thread): self.use_mla = use_mla self.block_len = block_len self._decode_tp_size = decode_tp_size - self.model_stream = torch_npu.npu.current_stream() + self.resharding_stream = resharding_stream self.current_layer = -1 - if self.pd_head_ratio > 1: - # regesit kv buffer for tp inequal - alignment = 2 * 1024 * 1024 - self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment, - dtype=first_kv_cache.dtype, - device=first_kv_cache.device) - self.k_buffer = align_memory( - self.k_buffer, alignment)[:first_kv_cache.numel()].view( - -1, first_kv_cache.shape[-1]) - self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment, - dtype=first_kv_cache.dtype, - device=first_kv_cache.device) - self.v_buffer = align_memory( - self.v_buffer, alignment)[:first_kv_cache.numel()].view( - -1, first_kv_cache.shape[-1]) - - for tensor in (self.k_buffer, self.v_buffer): - assert tensor.data_ptr( - ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M" - ret_value = self.engine.register_memory( - tensor.data_ptr(), tensor.numel()) - logger.info( - f"Register memory for prefill when pd head ratio > 1
{tensor.data_ptr()} {tensor.numel()} {ret_value=}" - ) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed. ") - - self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor, - torch.Tensor, torch.npu.Event]]() - + self.send_queue = queue.Queue[SendTask]() + self.k_buffer = k_buffer + self.v_buffer = v_buffer self.ready_event = ready_event self.callback_func = callback_func @@ -181,43 +179,36 @@ torch.npu.set_device(device) self.ready_event.set() while True: - req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get( - ) - self._handle_request(req_id, req_meta, layer_index, key, value, - reshape_cache_event) + send_task = self.send_queue.get() + self._handle_request(send_task) - def _handle_request(self, req_id, req_meta, layer_index, key, value, - reshape_cache_event): + def _handle_request(self, send_task: SendTask): try: - logger.debug( - f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." - ) - self._transfer_kv_cache(req_id, req_meta, layer_index, key, value, - reshape_cache_event) - logger.debug( - f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." - ) + self._transfer_kv_cache(send_task) except Exception as e: - logger.error("Failed to transfer KV cache for request " - f"{req_id}: {e}") + logger.error( + f"Failed to transfer KV cache for layer idx {send_task.layer_idx}, {e}" + ) - def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, - reshape_cache_event): + def get_transfer_meta(self, send_task: SendTask, req_id: str, + req_meta: ReqMeta): + src_list: list[int] = [] + dst_list: list[int] = [] + length_list: list[int] = [] # not need to send kv cache if self.tp_rank % self.num_head_replica != 0: logger.debug( f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})." ) - return + return (src_list, dst_list, length_list) if self.use_mla and self.tp_rank >= self._decode_tp_size: logger.debug( f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
) - return + return (src_list, dst_list, length_list) - remote_host = req_meta.remote_host + layer_idx = send_task.layer_idx remote_block_ids = req_meta.remote_block_ids - remote_te_port = req_meta.remote_te_rpc_port remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr local_kv_base_addr = self.kv_caches_base_addr local_block_ids = req_meta.local_block_ids @@ -225,17 +216,15 @@ class KVCacheSendingLayerThread(threading.Thread): if self.pd_head_ratio == 1: layer_local_kv_base_addr = [ local_kv_base_addr[i] - for i in [2 * layer_index, 2 * layer_index + 1] + for i in [2 * layer_idx, 2 * layer_idx + 1] ] layer_remote_kv_base_addr = [ - remote_kv_base_addrs[i] - for i in [2 * layer_index, 2 * layer_index + 1] + remote_kv_base_addrs[i] # type:ignore + for i in [2 * layer_idx, 2 * layer_idx + 1] ] grouped_remote_block_ids, grouped_local_block_ids = \ group_concurrent_contiguous(remote_block_ids, local_block_ids) - session_id = f"{remote_host}:{remote_te_port}" - src_list, dst_list, length_list = [], [], [] for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): block_len = self.block_len[ @@ -250,74 +239,101 @@ class KVCacheSendingLayerThread(threading.Thread): src_list.append(src) dst_list.append(dst) length_list.append(length) - if self.current_layer != layer_index: - self.current_layer = layer_index - """ - Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. - This issue will be fixed in CANN version 8.5.rc1. - You can manually build the master branch of the project at https://gitcode.com/cann/hixl - to resolve this issue before the 8.5.RC1 release. - """ - reshape_cache_event.synchronize() - ret = self.engine.batch_transfer_sync_write( - session_id, src_list, dst_list, length_list) - if ret < 0: - logger.error("Mooncake transfer failed for request %s", req_id) - raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") else: - key = key.view(-1, key.shape[-1]) - value = value.view(-1, key.shape[-1]) - self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] -> - self.v_buffer[:value.shape[0]].copy_(value) - + rearrange_block_ids = send_task.rearrange_block_ids + rearrange_block_dict = { + value: index + for index, value in enumerate( + rearrange_block_ids) # type:ignore + } layer_local_kv_base_addr = [ self.k_buffer.data_ptr(), self.v_buffer.data_ptr() ] layer_remote_kv_base_addr = [ - remote_kv_base_addrs[i] - for i in [2 * layer_index, 2 * layer_index + 1] + remote_kv_base_addrs[i] # type:ignore + for i in [2 * layer_idx, 2 * layer_idx + 1] ] - grouped_remote_block_ids, _ = group_concurrent_contiguous( - remote_block_ids) - - session_id = f"{remote_host}:{remote_te_port}" src_list, dst_list, length_list = [], [], [] for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): - src_layer_addr = src_layer_base_addr - for group_remote_block_id in grouped_remote_block_ids: - block_len = self.block_len[0] - remote_block_len = self.block_len[0] * self.pd_head_ratio - src_list.append(src_layer_addr) + block_len = self.block_len[0] + remote_block_len = self.block_len[0] * self.pd_head_ratio + for remote_block_id, local_block_id in zip( + remote_block_ids, local_block_ids): + src = src_layer_base_addr + rearrange_block_dict[ + local_block_id] * block_len + dst = dst_layer_base_addr + remote_block_id * remote_block_len + block_len * ( + (self.tp_rank // self.num_head_replica) % + self.pd_head_ratio) + src_list.append(src) + 
dst_list.append(dst) + length_list.append(block_len) + return (src_list, dst_list, length_list) - if src_layer_addr + len( - group_remote_block_id - ) * block_len > src_layer_base_addr + key.numel( - ) * key.element_size(): - length = src_layer_base_addr + key.numel( - ) * key.element_size() - src_layer_addr - else: - length = len(group_remote_block_id) * block_len - length_list.append(length) + def _transfer_kv_cache(self, send_task: SendTask): + if self.pd_head_ratio > 1: + with npu_stream_switch(self.resharding_stream): + key = send_task.k_cache + value = send_task.v_cache + key = key.view(-1, key.shape[-1]) # type:ignore + value = value.view(-1, value.shape[-1]) # type:ignore + self.k_buffer[:key.shape[0]].copy_(key) # stage reshaped K/V into the registered send buffers + self.v_buffer[:value.shape[0]].copy_(value) - dst_list.append(dst_layer_base_addr + - group_remote_block_id[0] * - remote_block_len + length * - ((self.tp_rank // self.num_head_replica) % - self.pd_head_ratio)) - src_layer_addr += length - self.model_stream.synchronize() - ret = self.engine.batch_transfer_sync_write( - session_id, src_list, dst_list, length_list) - if ret < 0: - logger.error("Mooncake transfer failed for request %s", req_id) - raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") + # Merge transmission tasks that belong to the same session + session_meta: dict[str, TransferMeta] = {} + for req_id, req_meta in send_task.send_request.items(): + session_id = f"{req_meta.remote_host}:{req_meta.remote_te_rpc_port}" + if session_id not in session_meta: + session_meta[session_id] = TransferMeta(src=[], + dst=[], + length=[], + req_ids=[]) - if layer_index == (self.total_layers - 1) and req_meta.chunk_finish: - self.callback_func(req_id, req_meta) + (src_list, dst_list, + length_list) = self.get_transfer_meta(send_task, req_id, req_meta) + + session_meta[session_id].src.extend(src_list) + session_meta[session_id].dst.extend(dst_list) + session_meta[session_id].length.extend(length_list) + session_meta[session_id].req_ids.append(req_id) + + if self.pd_head_ratio == 1: + """ + Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. + This issue will be fixed in CANN version 8.5.rc1. + You can manually build the master branch of the project at https://gitcode.com/cann/hixl + to resolve this issue before the 8.5.RC1 release.
+ """ + send_task.wait_event.synchronize() # type:ignore + elif self.pd_head_ratio > 1: + self.resharding_stream.synchronize() + + for session_id, transfer_meta in session_meta.items(): + if len(transfer_meta.src) > 0: + ret = self.engine.batch_transfer_sync_write( + session_id, transfer_meta.src, transfer_meta.dst, + transfer_meta.length) + if ret < 0: + logger.error( + f"Mooncake transfer failed for send requests {transfer_meta.req_ids} kv cache to {session_id}" + ) + if send_task.layer_idx == (self.total_layers - 1): + for req_id in transfer_meta.req_ids: + req_meta = send_task.send_request[req_id] + if req_meta.chunk_finish: + self.callback_func( + req_id, req_meta + ) # TODO Send a signal indicating transmission failure + else: + if send_task.layer_idx == (self.total_layers - 1): + for req_id in transfer_meta.req_ids: + req_meta = send_task.send_request[req_id] + if req_meta.chunk_finish: + self.callback_func(req_id, req_meta) class KVCacheRecvingLayerThread(threading.Thread): @@ -836,8 +852,10 @@ class MooncakeLayerwiseConnectorWorker: self.pd_tp_ratio = get_ascend_config().pd_tp_ratio self.pd_head_ratio = get_ascend_config().pd_head_ratio self.num_head_replica = get_ascend_config().num_head_replica + self.resharding_stream = None + if self.pd_head_ratio > 1: + self.resharding_stream = torch.npu.Stream() - self.first_kv_cache = None self.remote_poller = zmq.Poller() # type: ignore self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) self.encoder = msgspec.msgpack.Encoder() @@ -852,6 +870,8 @@ class MooncakeLayerwiseConnectorWorker: deque) self.remote_poller = zmq.Poller() # type: ignore self.timeout = 1.0 # seconds + self.k_buffer: Optional[torch.Tensor] = None + self.v_buffer: Optional[torch.Tensor] = None def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config @@ -874,12 +894,40 @@ class MooncakeLayerwiseConnectorWorker: assert "dp_size" in decode_parallel_config.keys() self._decode_dp_size = decode_parallel_config["dp_size"] + def create_kv_buffer(self, first_kv_cache): + if self.pd_head_ratio > 1: + # regesit kv buffer for tp inequal + alignment = 2 * 1024 * 1024 + self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment, + dtype=first_kv_cache.dtype, + device=first_kv_cache.device) + self.k_buffer = align_memory( + self.k_buffer, alignment)[:first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1]) + self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment, + dtype=first_kv_cache.dtype, + device=first_kv_cache.device) + self.v_buffer = align_memory( + self.v_buffer, alignment)[:first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1]) + + for tensor in (self.k_buffer, self.v_buffer): + assert tensor.data_ptr( + ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + ret_value = self.engine.register_memory( + tensor.data_ptr(), tensor.numel()) + logger.info( + f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}" + ) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed. 
") + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" _, first_kv_cache_tuple = next(iter(kv_caches.items())) first_kv_cache = first_kv_cache_tuple[0] - self.first_kv_cache = first_kv_cache + self.create_kv_buffer(first_kv_cache) # TODO(tms): Find a more robust way to detect and handle MLA self.use_mla = first_kv_cache_tuple[0].size( @@ -954,6 +1002,9 @@ class MooncakeLayerwiseConnectorWorker: block_len=self.block_len, decode_tp_size=self._decode_tp_size, first_kv_cache=first_kv_cache, + k_buffer=self.k_buffer, + v_buffer=self.v_buffer, + resharding_stream=self.resharding_stream, callback_func=self.send_done_send_signal) self.kv_send_layer_thread.start() ready_event.wait() @@ -1002,60 +1053,49 @@ class MooncakeLayerwiseConnectorWorker: reshape_cache_event = attn_metadata.reshape_cache_event if self.pd_head_ratio != 1: + assert self.resharding_stream is not None + with npu_stream_switch(self.resharding_stream): + reshape_cache_event.wait() + rearrange_block_ids = sorted({ + block_id + for request in connector_metadata.requests.values() + for block_id in request.local_block_ids + }) - def sort_kv_cache(input_kv: list[list[int]]): - return torch.cat([ - torch.chunk(tensor, self.pd_head_ratio, dim=0)[x] - for x in range(self.pd_head_ratio) - for tensor in input_kv - ]) - - total_block_ids = [ - request.local_block_ids - for request in connector_metadata.requests.values() - ] - keys = [ - kv_layer[0][block_ids].reshape( - -1, *kv_layer[0].shape[2:]).clone() - for block_ids in total_block_ids - ] - values = [ - kv_layer[1][block_ids].reshape( - -1, *kv_layer[1].shape[2:]).clone() - for block_ids in total_block_ids - ] - key_block_size = keys[0].size(0) // len(total_block_ids[0]) - value_block_size = values[0].size(0) // len(total_block_ids[0]) - keys = sort_kv_cache(keys) # [req1_key, req2_key] - values = sort_kv_cache(values) - (keys, - values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, - values) - key_start_id = 0 - value_start_id = 0 + keys = kv_layer[0][rearrange_block_ids].clone() + values = kv_layer[1][rearrange_block_ids].clone() + # sort kv caches for each block + keys = keys.view(keys.size(0), self.pd_head_ratio, -1, + *keys.shape[2:]).transpose( + 0, 1).reshape_as(keys) + values = values.view(values.size(0), self.pd_head_ratio, + -1, *values.shape[2:]).transpose( + 0, 1).reshape_as(values) + # reshard kv cache + keys = keys.reshape(-1, *kv_layer[0].shape[2:]) + values = values.reshape(-1, *kv_layer[1].shape[2:]) + (keys, values) = kv_alltoall_and_rearrange( + self.pd_head_ratio, keys, values) else: - key = None - value = None + keys = None + values = None + rearrange_block_ids = None + + assert self.kv_send_layer_thread is not None + assert reshape_cache_event is not None + send_task = SendTask(wait_event=reshape_cache_event, + k_cache=keys, + v_cache=values, + layer_idx=self.current_layer, + rearrange_block_ids=rearrange_block_ids) for req_id, req_meta in connector_metadata.requests.items(): - if self.pd_head_ratio != 1: - key_block_num = len( - req_meta.local_block_ids) * key_block_size - value_block_num = len( - req_meta.local_block_ids) * value_block_size - key = keys[key_start_id:key_start_id + key_block_num] - value = values[value_start_id:value_start_id + - value_block_num] - key_start_id += key_block_num - value_start_id += value_block_num req_meta_update = self.update_decoder_info(req_id, req_meta) logger.debug( f"Add request {req_id} to kv send layer thread. 
{req_meta_update=}" ) - assert self.kv_send_layer_thread is not None - assert reshape_cache_event is not None - self.kv_send_layer_thread.send_queue.put( - (req_id, req_meta_update, self.current_layer, key, value, - reshape_cache_event)) + send_task.send_request[req_id] = req_meta_update + + self.kv_send_layer_thread.send_queue.put(send_task) self.current_layer += 1 def _get_remote_socket( @@ -1106,6 +1146,14 @@ logger.info( f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" ) + session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" + ret = self.engine.batch_transfer_sync_write( + session_id, [self.kv_caches_base_addr[0]], + [agent_meta.kv_caches_base_addr[0]], [128]) + if ret < 0: + logger.error( + f"Mooncake transfer failed to create link to device {session_id}" + ) req_meta_update.remote_te_rpc_port = self.remote_te_port[ req_meta_update.remote_engine_id][req_meta_update.remote_port] req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
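
The "session fusion" change above is easiest to read in isolation: per layer, every request contributes (src, dst, length) address segments, and segments bound for the same decode instance (a "session", keyed by remote host and transfer-engine RPC port) are coalesced so that a single `batch_transfer_sync_write` is issued per destination instead of one per request. Below is a minimal, self-contained sketch of that grouping; it mirrors the `TransferMeta` accumulation in `_transfer_kv_cache`, but the `StubEngine`, the dict-based request metadata, and `fake_meta` are made-up stand-ins for illustration, not part of the patch.

```python
from dataclasses import dataclass, field


@dataclass
class TransferMeta:
    src: list[int] = field(default_factory=list)
    dst: list[int] = field(default_factory=list)
    length: list[int] = field(default_factory=list)
    req_ids: list[str] = field(default_factory=list)


def fuse_and_write(engine, send_request, get_transfer_meta):
    """Coalesce per-request segments by destination session, then issue one
    batched write per session instead of one write per request."""
    session_meta: dict[str, TransferMeta] = {}
    for req_id, req_meta in send_request.items():
        session_id = f"{req_meta['remote_host']}:{req_meta['remote_te_rpc_port']}"
        meta = session_meta.setdefault(session_id, TransferMeta())
        src, dst, length = get_transfer_meta(req_id, req_meta)
        meta.src.extend(src)
        meta.dst.extend(dst)
        meta.length.extend(length)
        meta.req_ids.append(req_id)
    for session_id, meta in session_meta.items():
        if meta.src:  # a session may be empty if its requests were filtered out
            ret = engine.batch_transfer_sync_write(session_id, meta.src,
                                                   meta.dst, meta.length)
            if ret < 0:
                raise RuntimeError(f"transfer to {session_id} failed: {ret}")


class StubEngine:
    """Records calls instead of performing RDMA; for the demo only."""

    def __init__(self):
        self.calls = []

    def batch_transfer_sync_write(self, session_id, src, dst, length):
        self.calls.append((session_id, src, dst, length))
        return 0


if __name__ == "__main__":
    engine = StubEngine()
    reqs = {
        "r1": {"remote_host": "10.0.0.1", "remote_te_rpc_port": 9000},
        "r2": {"remote_host": "10.0.0.1", "remote_te_rpc_port": 9000},
        "r3": {"remote_host": "10.0.0.2", "remote_te_rpc_port": 9000},
    }

    def fake_meta(req_id, req_meta):
        # Each request contributes one fake 64-byte segment.
        return [0x1000], [0x2000], [64]

    fuse_and_write(engine, reqs, fake_meta)
    assert len(engine.calls) == 2  # two sessions, not three requests
```

With N in-flight requests fanned out to M decode instances, this reduces the number of transfer-engine calls per layer from N to at most M, which is where the performance win claimed in the PR description comes from.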