From b881fab41687f159cce951291c583c1949c107a8 Mon Sep 17 00:00:00 2001 From: wangxiaoteng888 <56506195+wangxiaoteng888@users.noreply.github.com> Date: Thu, 12 Feb 2026 11:02:25 +0800 Subject: [PATCH] [P/D][PCP] mooncake layerwise support pcp function (#6627) ### What this PR does / why we need it? mooncake layerwise support pcp function PCP (Prefill Context Parallelism) Support: Introduced explicit support for Prefill Context Parallelism (PCP) and Decode Context Parallelism (DCP) in the Mooncake layerwise KV cache transfer mechanism, allowing for more granular control and awareness of parallel configurations during data transfer. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By ci - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d7e17aaacd5ed1b4b4be6bcfef3a1b7cbc84fc9a --------- Signed-off-by: wangxiaoteng Signed-off-by: liziyu Co-authored-by: liziyu --- tests/ut/attention/test_attention_cp.py | 4 +- .../test_mooncake_layerwise_connector.py | 81 ++-- tests/ut/kv_connector/utils.py | 77 ++-- .../context_parallel/attention_cp.py | 5 +- .../attention/context_parallel/mla_cp.py | 4 + .../kv_p2p/mooncake_layerwise_connector.py | 360 +++++++++++------- .../distributed/kv_transfer/utils/utils.py | 243 ++++++++++++ 7 files changed, 551 insertions(+), 223 deletions(-) diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index 487d4169..3cdbeb6b 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -232,15 +232,17 @@ class TestAscendAttentionCPImpl(TestBase): self.assertEqual(value.shape[1], num_heads) self.assertEqual(value.shape[2], head_size) + @patch('torch_npu.Event', create=True) @patch('torch_npu._npu_reshape_and_cache') @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) - def test_reshape_and_cache(self, mock_npu_reshape_and_cache): + def test_reshape_and_cache(self, mock_event_class, mock_npu_reshape_and_cache): num_tokens = 4 block_num = 100 block_size = 128 num_heads = 1 head_size = 128 self.impl.head_size = head_size + self.impl.is_kv_producer = False kv_cache = (torch.randn(block_num, block_size, num_heads, head_size), torch.randn(block_num, block_size, num_heads, head_size)) diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index f8266580..bdb1b02f 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -61,8 +61,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): 6000], # 2 * total_layers use_mla=True, block_len=[1024, 2048], - decode_tp_size=1, - first_kv_cache=self.first_kv_cache, k_buffer=self.fake_k_buffer, v_buffer=self.fake_v_buffer, resharding_stream=fake_resharding_stream, @@ -70,6 +68,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): self.req_meta_base = ReqMeta( local_block_ids=[5, 8], + remote_tp_size = 8, + remote_pcp_size = 1, + remote_dcp_size = 1, token_ids=[1, 2, 3], remote_block_ids=[10, 20], remote_engine_id="remote_engine", @@ -112,8 +113,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): kv_cache_base_addr=[1111, 2222, 3333, 4444], use_mla=False, block_len=[64], - decode_tp_size=1, - first_kv_cache=self.first_kv_cache, k_buffer=self.fake_k_buffer, v_buffer=self.fake_v_buffer, resharding_stream=fake_resharding_stream, @@ -242,12 +241,12 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): th.task_tracker["reqX"] = 0 th.request_map["reqX"] = "reqX" - th.update_task("reqX") + th.update_task("reqX", 2) with th.lock: self.assertIn("reqX", th.task_tracker) self.assertNotIn("reqX", th.done_requests) - th.update_task("reqX") + th.update_task("reqX", 2) with th.lock: self.assertNotIn("reqX", th.task_tracker) self.assertIn("reqX", th.done_requests) @@ -284,7 +283,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): dec_inst = MagicMock() dec_inst.decode.side_effect = [ (GET_META_MSG, ), - (DONE_SENDING_MSG, "reqA"), + (DONE_SENDING_MSG, "reqA", 1), (b"weird_msg", ), ] mock_Decoder.return_value = dec_inst @@ -339,21 +338,11 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): finished = th.get_and_clear_finished_requests() self.assertIn("reqA", finished) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", - return_value="127.0.0.1") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", return_value="127.0.0.1") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx") def test_run_loop_pd_head_ratio_gt1_requires_multiple_done( self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip, _mock_logger): @@ -364,8 +353,8 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): dec_inst = MagicMock() dec_inst.decode.side_effect = [ - (DONE_SENDING_MSG, "reqB"), - (DONE_SENDING_MSG, "reqB"), + (DONE_SENDING_MSG, "reqB", 2), + (DONE_SENDING_MSG, "reqB", 2), ] mock_Decoder.return_value = dec_inst @@ -373,25 +362,26 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): sock.recv_multipart.side_effect = [ [b"ID", b"PAY1"], [b"ID", b"PAY2"], - SystemExit, + SystemExit, # 退出循环 ] cm = MagicMock() cm.__enter__.return_value = sock mock_zmq_ctx.return_value = cm - th = KVCacheRecvingLayerThread(tp_rank=0, - side_channel_port=5555, - tp_size=2, - pd_head_ratio=2, - local_engine_id="engineY", - metadata=self.meta, - ready_event=self.ready_event) + th = KVCacheRecvingLayerThread( + tp_rank=0, + side_channel_port=5555, + tp_size=2, + pd_head_ratio=2, + local_engine_id="engineY", + metadata=self.meta, + ready_event=self.ready_event + ) with th.lock: th.task_tracker["reqB"] = 0 th.request_map["reqB"] = "reqB" with self.assertRaises(SystemExit): th.run() - finished = th.get_and_clear_finished_requests() self.assertIn("reqB", finished) @@ -441,6 +431,7 @@ class MockRequest: self.kv_transfer_params = kv_transfer_params or {} self.status = status or "running" self.output_token_ids = [101, 102] + self.num_computed_tokens = 0 self.all_token_ids = list(self.prompt_token_ids) @@ -565,7 +556,8 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): req = MockRequest("req_u1", prompt_token_ids=list(range(24)), kv_transfer_params={"do_remote_prefill": True}) - blocks = _MockBlocks(unhashed=[4, 5, 6]) + req.num_computed_tokens = 0 + blocks = _MockBlocks(unhashed=[4, 5, 6], block_ids_tuple=([4, 5, 6], )) self.scheduler.update_state_after_alloc(req, blocks, @@ -592,7 +584,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): info = self.scheduler._reqs_need_send_layerwise["req_u2"] self.assertEqual(info.local_block_ids, [7, 8, 9]) self.assertIs(info.request, req) - self.assertEqual(info.remote_block_ids, []) def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True @@ -663,12 +654,13 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): send_req_info.update_computed_tokens = MagicMock() send_req_info.update_transferred_tokens = MagicMock() send_req_info.unpack = MagicMock( - return_value=(send_req_info.local_block_ids, - send_req_info.remote_block_ids, - send_req_info.remote_cache_tokens, - send_req_info.local_transferred_tokens, - send_req_info.local_computed_tokens, - send_req_info.request)) + return_value=( + send_req_info.local_block_ids, + send_req_info.local_transferred_tokens, + send_req_info.local_computed_tokens, + send_req_info.request + ) + ) self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info out = _MockSchedulerOutput( @@ -920,6 +912,11 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): self.vllm_config = MockVllmConfig() self.engine_id = "test_engine" self.kv_caches = {"layer1": (MagicMock(), MagicMock())} + self.vllm_config.parallel_config.tensor_parallel_size = 1 + self.vllm_config.parallel_config.prefill_context_parallel_size = 1 + self.vllm_config.parallel_config.decode_context_parallel_size = 1 + self.vllm_config.parallel_config.data_parallel_rank = 0 + self.vllm_config.kv_transfer_config.kv_port = 1234 def tearDown(self): for p in self.patches: @@ -956,4 +953,4 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): self.engine_id) worker.register_kv_caches(mla_caches) self.assertTrue(worker.use_mla) - self.assertEqual(len(worker.block_len), 2) + self.assertEqual(len(worker.block_len), 2) \ No newline at end of file diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py index 582830d8..6a560e80 100644 --- a/tests/ut/kv_connector/utils.py +++ b/tests/ut/kv_connector/utils.py @@ -8,14 +8,11 @@ from typing import Any, Optional import torch from vllm import SamplingParams -from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, - ModelConfig, SchedulerConfig, VllmConfig) +from vllm.config import CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.utils.hashing import sha256 -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager @@ -37,14 +34,10 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block) == 0 - num_free_blocks = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - assert num_free_blocks == ( - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block) == 0 + num_free_blocks = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. @@ -63,8 +56,7 @@ def create_vllm_config( max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_num_batched_tokens, ) - fake_weight_path = os.path.join(os.path.dirname(__file__), "..", - "fake_weight") + fake_weight_path = os.path.join(os.path.dirname(__file__), "..", "fake_weight") model_config = ModelConfig( model=fake_weight_path, skip_tokenizer_init=True, @@ -77,14 +69,14 @@ def create_vllm_config( cache_dtype="auto", enable_prefix_caching=True, ) - kv_transfer_config = KVTransferConfig( - kv_connector="MooncakeConnectorV1", - kv_role="kv_both") - return VllmConfig(scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - device_config=DeviceConfig("cpu")) + kv_transfer_config = KVTransferConfig(kv_connector="MooncakeConnectorV1", kv_role="kv_both") + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu"), + ) def create_scheduler( @@ -96,11 +88,7 @@ def create_scheduler( kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float16, - False, False)) - ], + kv_cache_groups=[KVCacheGroupSpec(["layer"], FullAttentionSpec(block_size, 1, 1, torch.float16, False, False))], ) vllm_config.cache_config.num_gpu_blocks = num_blocks @@ -138,19 +126,19 @@ def create_request( if do_remote_decode: assert not do_remote_prefill - kv_transfer_params = dict(do_remote_prefill=False, - do_remote_decode=True) + kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True) elif do_remote_prefill: - kv_transfer_params = dict(do_remote_prefill=True, - do_remote_decode=False, - remote_engine_id="my-engine-id", - remote_block_ids=list( - range(num_remote_blocks)), - remote_host="my-host", - remote_port=1234, - remote_tp_size=1, - remote_pcp_size=1, - remote_dcp_size=1) + kv_transfer_params = dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list(range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234, + remote_tp_size=1, + remote_pcp_size=1, + remote_dcp_size=1, + ) max_tokens = 1 if do_remote_decode else max_tokens sampling_params = SamplingParams(max_tokens=max_tokens) @@ -190,10 +178,9 @@ def create_model_runner_output( # Make output data structure. extra_args = {} - from vllm.v1.worker.kv_connector_model_runner_mixin import \ - KVConnectorOutput # type: ignore # noqa - kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, - finished_recving=finished_recving) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # type: ignore # noqa + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) extra_args = {"kv_connector_output": kv_connector_output} model_runner_output = ModelRunnerOutput( diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 8bffb12f..c2f919fe 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -743,6 +743,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): has_prefill = attn_metadata.num_prefills > 0 if len(kv_cache) > 1: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] @@ -778,7 +780,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): value_cache=self.value_cache, slot_indices=slot_mapping, ) - + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() return key, value def _gather_global_context_output(self, local_context_attn_output): diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index de1bc5f3..d30ce725 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -414,9 +414,13 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe prefill_k_c_normed = prefill_k_c_normed.squeeze() slot_mapping = attn_metadata.slot_mapping[self.pcp_size * num_decode_tokens :] + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() torch_npu._npu_reshape_and_cache( key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slot_mapping ) + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() prefill_k_nope, prefill_value = ( self.kv_b_proj(prefill_k_c_normed)[0] .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 03584cd3..7ea8698e 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -23,9 +23,16 @@ import torch_npu import zmq from mooncake.engine import TransferEngine # type: ignore from vllm.config import VllmConfig +from vllm.distributed import get_pcp_group from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group, get_world_group +from vllm.distributed.parallel_state import ( + get_decode_context_model_parallel_rank, + get_tensor_model_parallel_rank, + get_tp_group, + get_world_group, +) from vllm.logger import logger +from vllm.utils.math_utils import round_down from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig @@ -35,8 +42,13 @@ from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import GET_ME from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te from vllm_ascend.distributed.kv_transfer.utils.utils import ( align_memory, + context_parallel_parameters_check, + get_cp_group, + get_local_remote_block_port_mappings, + get_transfer_mappings, get_transfer_timeout_value, kv_alltoall_and_rearrange, + parallel_info, ) from vllm_ascend.utils import npu_stream_switch @@ -68,7 +80,15 @@ class ReqMeta: remote_te_rpc_port: int | None remote_kv_caches_base_addr: list[int] | None metaserver: str | None - chunk_finish: bool | None + remote_tp_size: int | None + remote_pcp_size: int | None + remote_dcp_size: int | None + chunk_finish: bool = False + prompt_len: int = 0 + trans_count: int = 0 + remote_cache_tokens: int = 0 + local_computed_tokens: int = 0 + local_transed_tokens: int = 0 @dataclass @@ -100,8 +120,6 @@ class TransferMeta: @dataclass class SendReqInfo: local_block_ids: list[int] - remote_block_ids: list[int] - remote_cache_tokens: int local_transferred_tokens: int local_computed_tokens: int request: "Request" @@ -121,8 +139,6 @@ class SendReqInfo: def unpack(self): return ( self.local_block_ids, - self.remote_block_ids, - self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request, @@ -161,8 +177,6 @@ class KVCacheSendingLayerThread(threading.Thread): kv_cache_base_addr: list[int], use_mla: bool, block_len: list[int], - decode_tp_size: int, - first_kv_cache: torch.Tensor, k_buffer: torch.Tensor, v_buffer: torch.Tensor, resharding_stream: torch.npu.Stream, @@ -178,7 +192,6 @@ class KVCacheSendingLayerThread(threading.Thread): self.use_mla = use_mla self.use_sparse = len(block_len) == 3 self.block_len = block_len - self._decode_tp_size = decode_tp_size self.resharding_stream = resharding_stream self.current_layer = -1 @@ -373,10 +386,10 @@ class KVCacheRecvingLayerThread(threading.Thread): self.done_requests = set() return finished_requests - def update_task(self, req_id): + def update_task(self, req_id, trans_count): with self.lock: self.task_tracker[req_id] += 1 - if self.task_tracker[req_id] == self.pd_head_ratio: + if self.task_tracker[req_id] == trans_count: self.task_tracker.pop(req_id) self.done_requests.add(self.request_map[req_id]) self.request_map.pop(req_id) @@ -411,7 +424,8 @@ class KVCacheRecvingLayerThread(threading.Thread): elif msg[0] == DONE_SENDING_MSG: logger.debug("Got DONE_RECVING_MSG for request %s", msg[1]) request_id = msg[1] - self.update_task(request_id) + trans_count = msg[2] + self.update_task(request_id, trans_count) sock.send_multipart((identity, b"", b"ACK")) else: logger.error("Connection listener got unexpected message %s", msg) @@ -431,6 +445,10 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): kv_transfer_params: dict[str, Any], token_ids: list[int] | None = None, chunk_finish: bool = False, + prompt_len: int = 0, + remote_cache_tokens: int = 0, + local_computed_tokens: int = 0, + local_transed_tokens: int = 0, ): self.requests[request_id] = ReqMeta( token_ids=token_ids or [], @@ -442,7 +460,14 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port"), remote_kv_caches_base_addr=kv_transfer_params.get("remote_kv_caches_base_addr"), metaserver=kv_transfer_params.get("metaserver"), + remote_tp_size=kv_transfer_params.get("remote_tp_size"), + remote_pcp_size=kv_transfer_params.get("remote_pcp_size"), + remote_dcp_size=kv_transfer_params.get("remote_dcp_size"), chunk_finish=chunk_finish, + remote_cache_tokens=remote_cache_tokens, + local_computed_tokens=local_computed_tokens, + prompt_len=prompt_len, + local_transed_tokens=local_transed_tokens, ) @@ -605,7 +630,8 @@ class MooncakeLayerwiseConnectorScheduler: ) if params is not None and params.get("do_remote_prefill"): - local_block_ids = blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] + local_block_ids = (blocks.get_block_ids()[0]) if num_external_tokens > 0 else [] + remote_cached_tokens = request.num_computed_tokens # Get unhashed blocks to pull from remote. logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue" @@ -632,6 +658,10 @@ class MooncakeLayerwiseConnectorScheduler: remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, + remote_tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + remote_pcp_size=self.vllm_config.parallel_config.prefill_context_parallel_size, + remote_dcp_size=self.vllm_config.parallel_config.decode_context_parallel_size, + remote_cached_tokens=remote_cached_tokens, ) future = self.executor.submit( @@ -658,8 +688,6 @@ class MooncakeLayerwiseConnectorScheduler: local_computed_tokens = 0 self._reqs_need_send_layerwise[request.request_id] = SendReqInfo( local_block_ids=local_block_ids, - remote_block_ids=remote_block_ids, - remote_cache_tokens=remote_cache_tokens, local_transferred_tokens=local_transferred_tokens, local_computed_tokens=local_computed_tokens, request=request, @@ -691,11 +719,9 @@ class MooncakeLayerwiseConnectorScheduler: cached_reqs = scheduler_output.scheduled_cached_reqs new_reqs = scheduler_output.scheduled_new_reqs scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens - # update local block ids for req_id, new_blocks in zip(cached_reqs.req_ids, cached_reqs.new_block_ids): if req_id in self._reqs_need_send_layerwise and new_blocks is not None: self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0]) - computed_tokens = dict( list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + [(x.req_id, x.num_computed_tokens) for x in new_reqs] @@ -703,6 +729,10 @@ class MooncakeLayerwiseConnectorScheduler: for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(): if req_id in self._reqs_need_send_layerwise: send_req_info = self._reqs_need_send_layerwise[req_id] + # update local transferred tokens + send_req_info.update_transferred_tokens( + round_down(send_req_info.local_computed_tokens, self.block_size) + ) # update local computed tokens, not transfer spec decode tokens spec_decode_tokens = ( len(scheduled_spec_decode_tokens[req_id]) if (req_id in scheduled_spec_decode_tokens) else 0 @@ -714,56 +744,36 @@ class MooncakeLayerwiseConnectorScheduler: def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False): ( local_block_ids, - remote_block_ids, - remote_cache_tokens, - local_transferred_tokens, + local_transed_tokens, local_computed_tokens, request, ) = send_req_info.unpack() - local_trans_block_ids = local_block_ids[ - (local_transferred_tokens // self.block_size) : (local_computed_tokens // self.block_size) - ] - remote_trans_block_ids = remote_block_ids[ - ((local_transferred_tokens - remote_cache_tokens) // self.block_size) : ( - (local_computed_tokens - remote_cache_tokens) // self.block_size - ) - ] - request.kv_transfer_params["remote_block_ids"] = remote_trans_block_ids - assert len(local_trans_block_ids) == len(remote_trans_block_ids), ( - f"len of local trans block ids : {len(local_trans_block_ids)} not equal to " - f"the len of remote trans block ids : {len(remote_trans_block_ids)}" - ) - adjusted_tokens = ( - local_computed_tokens - (self.block_size - 1) if chunk_finish else local_computed_tokens - ) - logger.info( - f"MooncakeLayerwiseConnector scheduler add transfer task: " - f"{req_id=} {local_block_ids=} {remote_block_ids=} " - f"{local_trans_block_ids=} {remote_trans_block_ids=} " - f"local_computed_tokens={adjusted_tokens} " - f"request.all_token_ids={len(request.all_token_ids)}" - ) meta.add_new_req( request_id=req_id, - local_block_ids=local_trans_block_ids, + local_block_ids=local_block_ids, kv_transfer_params=request.kv_transfer_params, token_ids=[], chunk_finish=chunk_finish, + remote_cache_tokens=request.kv_transfer_params.get("remote_cached_tokens"), + prompt_len=len(request.all_token_ids), + local_computed_tokens=local_computed_tokens, + local_transed_tokens=local_transed_tokens, + ) + logger.debug( + f"MooncakeLayerwiseConnector build_connector_meta: {req_id=}" + f"prompt_len={len(request.all_token_ids)} {local_computed_tokens=}" + f"{local_transed_tokens=}" + f"remote_cache_tokens={request.kv_transfer_params.get('remote_cached_tokens')}" + f"{chunk_finish=} {local_block_ids=}" + f"remote_block_ids={request.kv_transfer_params.get('remote_block_ids')}" ) - # update local_transferred_tokens - local_transferred_tokens = (local_computed_tokens // self.block_size) * self.block_size - send_req_info.update_transferred_tokens(local_transferred_tokens) - # no chunk or last chunk - if send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids): - send_req_info.update_computed_tokens(send_req_info.local_computed_tokens + self.block_size - 1) - add_tranfer_task(req_id, send_req_info, chunk_finish=True) + # whether chunk finish + chunk_finish = send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids) + + add_tranfer_task(req_id, send_req_info, chunk_finish=chunk_finish) + if chunk_finish: self._reqs_need_send_layerwise.pop(req_id) - # chunk - elif (send_req_info.local_computed_tokens // self.block_size) - ( - send_req_info.local_transferred_tokens // self.block_size - ) > 0: - add_tranfer_task(req_id, send_req_info) return meta def _access_metaserver(self, url, message): @@ -796,13 +806,7 @@ class MooncakeLayerwiseConnectorWorker: """Implementation of Worker side methods""" def __init__(self, vllm_config: VllmConfig, engine_id: str): - self._get_prefill_decode_size(vllm_config) os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(get_transfer_timeout_value()) - if self._prefill_tp_size < self._decode_tp_size: - raise ValueError( - f"prefill_tp_size: {self._prefill_tp_size} must be greater than" - f" or equal to the decode_tp_size: {self._decode_tp_size}" - ) if TransferEngine is None: raise RuntimeError("mooncake is not available") @@ -814,11 +818,20 @@ class MooncakeLayerwiseConnectorWorker: self.engine_id = engine_id self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 + self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 self.tp_group = get_tp_group() + self._decode_tp_size: int | None = None self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) self.use_mla = self.vllm_config.model_config.use_mla + if self.use_mla: + self.total_num_kv_heads = 1 + else: + self.total_num_kv_heads = self.vllm_config.model_config.get_total_num_kv_heads() # Handshake base port self.side_channel_port = ( @@ -863,23 +876,6 @@ class MooncakeLayerwiseConnectorWorker: self.k_buffer: torch.Tensor | None = None self.v_buffer: torch.Tensor | None = None - def _get_prefill_decode_size(self, vllm_config: VllmConfig): - # get prefill tp and dp size from extra config - prefill_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {}) - - assert "tp_size" in prefill_parallel_config - self._prefill_tp_size = prefill_parallel_config["tp_size"] - - assert "dp_size" in prefill_parallel_config - self._prefill_dp_size = prefill_parallel_config["dp_size"] - - # get decode tp and dp size from extra config - decode_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("decode", {}) - assert "tp_size" in decode_parallel_config - self._decode_tp_size = decode_parallel_config["tp_size"] - assert "dp_size" in decode_parallel_config - self._decode_dp_size = decode_parallel_config["dp_size"] - def create_kv_buffer(self, first_kv_cache): if self.pd_head_ratio > 1: # regesit kv buffer for tp inequal @@ -977,8 +973,6 @@ class MooncakeLayerwiseConnectorWorker: kv_cache_base_addr=self.kv_caches_base_addr, use_mla=self.use_mla, block_len=self.block_len, - decode_tp_size=self._decode_tp_size, - first_kv_cache=first_kv_cache, k_buffer=self.k_buffer, v_buffer=self.v_buffer, resharding_stream=self.resharding_stream, @@ -1009,9 +1003,120 @@ class MooncakeLayerwiseConnectorWorker: else set() ) if len(done_recving) > 0: - logger.info("Number of completed KV cache recv requests: %d, receive requests: %d", 0, len(done_recving)) + logger.info( + f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}" + ) return set(), done_recving + # {(ip, port)]: {local_block_ids: [], remote_block_ids: {}}} + def _get_kv_split_metadata(self, req_meta, req_idx, req_id): + remote_pcp_size = req_meta.remote_pcp_size + remote_dcp_size = req_meta.remote_dcp_size + remote_tp_size = req_meta.remote_tp_size + remote_hosts = [req_meta.remote_host] + remote_port = req_meta.remote_port + local_transed_tokens = max(req_meta.remote_cache_tokens, req_meta.local_transed_tokens) + # local_transed_tokens tokens that have already been transmitted on the local side + local_computed_tokens = req_meta.local_computed_tokens + prompt_len = req_meta.prompt_len + p_parallel_info = parallel_info( + tp_size=self.tp_size, + pcp_size=self.pcp_size, + dcp_size=self.dcp_size, + pd_head_ratio=self.pd_head_ratio, + use_mla=self.use_mla, + ) + d_parallel_info = parallel_info( + tp_size=remote_tp_size, + pcp_size=remote_pcp_size, + dcp_size=remote_dcp_size, + pd_head_ratio=self.pd_head_ratio, + use_mla=self.use_mla, + ) + cp_size = self.pcp_size * self.dcp_size + # to_trans_idx all tokens that have been processed up to the current step + if req_meta.chunk_finish: + to_trans_idx = math.ceil(local_computed_tokens / self.block_size) + else: + to_trans_idx = math.floor(local_computed_tokens / self.block_size) + prompt_block_size = math.ceil(prompt_len / self.block_size) + # + num_local_blocks = prompt_block_size // cp_size + int( + (prompt_block_size % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank) + ) + already_send_blocks = to_trans_idx // cp_size + int( + (to_trans_idx % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank) + ) + if num_local_blocks == already_send_blocks: + req_meta.chunk_finish = True + transed_idx = math.floor(local_transed_tokens / self.block_size) + + p_cp_group = get_cp_group(self.tp_size, self.total_num_kv_heads, self.dcp_size) + d_cp_group = get_cp_group(remote_tp_size, self.total_num_kv_heads, remote_dcp_size) + logger.debug(f"Compute cp group for P&D {req_id=} {p_cp_group=} {d_cp_group=}") + + cp_ratio = len(p_cp_group) // len(d_cp_group) + if cp_ratio == 0: + selected_p_cp_groups = p_cp_group + selected_d_cp_groups = d_cp_group + else: + x = req_idx % cp_ratio + start = x * len(d_cp_group) + selected_p_cp_groups = p_cp_group[start : (start + len(d_cp_group))] + selected_d_cp_groups = d_cp_group + assert len(selected_p_cp_groups) == len(selected_d_cp_groups) + + p_head_group_rank = (self.tp_rank - self.dcp_rank) // self.dcp_size + selected_p_cp_group = [] + selected_d_cp_group = [] + for idx, cp_group in enumerate(selected_p_cp_groups): + if p_head_group_rank in cp_group: # Check whether the rank is in selected_p_cp_groups + selected_p_cp_group = cp_group + selected_d_cp_group = selected_d_cp_groups[idx] + if len(selected_p_cp_group) == 0: + return {} + + logger.debug( + f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} " + f"P-side selected head_group cp group: {selected_p_cp_group}, " + f"D-side selected head_group cp group: {selected_d_cp_group}" + ) + + context_parallel_parameters_check( + remote_pcp_size, remote_dcp_size, p_parallel_info, d_parallel_info, self.total_num_kv_heads + ) + p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping = ( + get_local_remote_block_port_mappings( + to_trans_idx, + p_parallel_info, + d_parallel_info, + remote_hosts, + remote_port, + selected_p_cp_group, + selected_d_cp_group, + prompt_len, + self.block_size, + req_meta, + self.total_num_kv_heads, + req_id, + ) + ) + transfer_mappings = get_transfer_mappings( + p_rank_block_mapping, + d_block_rank_mapping, + pd_head_mapping, + d_trans_count_mapping, + req_meta, + p_parallel_info, + req_id, + transed_idx, + to_trans_idx, + self.tp_rank, + self.pcp_rank, + self.dcp_rank, + ) + return transfer_mappings + def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): """Start loading KV blocks from remote engine.""" self.current_layer = 0 @@ -1023,31 +1128,29 @@ class MooncakeLayerwiseConnectorWorker: self.kv_recv_layer_thread.task_tracker[external_req_id] = 0 self.kv_recv_layer_thread.request_map[external_req_id] = req_id elif self.vllm_config.kv_transfer_config.is_kv_producer: - # select req to send - if self.use_mla or self.use_sparse: - num_need_send = self._decode_tp_size - else: - num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads - if self.tp_size <= num_kv_head: - num_need_send = self.tp_size - else: - num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head - num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1 - replica_group_idx = self.tp_rank % num_replica_groups - req_ids = sorted(list(metadata.requests.keys())) - selected_req_ids = [ - req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx - ] - request_ids = list(metadata.requests.keys()) - for req_id in request_ids: - if req_id not in selected_req_ids: - metadata.requests.pop(req_id) + # update trans info + update_metadata = {} + for req_idx, (req_id, req_meta) in enumerate(metadata.requests.items()): + self._decode_tp_size = req_meta.remote_tp_size + transfer_mappings = self._get_kv_split_metadata(req_meta, req_idx, req_id) + assert len(transfer_mappings) <= 1, f"Not support add mutil transfer task for req_id:{req_id}" + update_req_meta = copy.deepcopy(req_meta) + for (host, port), block_dict in transfer_mappings.items(): + update_req_meta.remote_host = host + update_req_meta.remote_port = port + update_req_meta.local_block_ids = block_dict["local_block_ids"] + update_req_meta.remote_block_ids = block_dict["remote_block_ids"] + update_req_meta.trans_count = block_dict["trans_count"] + update_metadata[req_id] = update_req_meta + metadata.requests = {} + for req_id, req_meta in update_metadata.items(): + metadata.requests[req_id] = update_metadata[req_id] # update send task trans block info if self.pd_head_ratio != 1: send_task = metadata.send_task send_task.rearrange_block_ids = sorted( - {block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids} + {block_id for req_id in metadata.requests for block_id in metadata.requests[req_id].local_block_ids} ) device = self.k_buffer.device # type: ignore @@ -1070,7 +1173,7 @@ class MooncakeLayerwiseConnectorWorker: ) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(): - # enable decode prefix cache + # get reshape and cache event if self.use_mla or self.use_sparse: reshape_cache_event = attn_metadata[layer_name].reshape_cache_event else: @@ -1156,59 +1259,48 @@ class MooncakeLayerwiseConnectorWorker: return sock def update_decoder_info(self, req_id, req_meta): - req_meta_update = copy.deepcopy(req_meta) - if self.use_mla or self.use_sparse: - pd_tp_ratio = self.tp_size // self._decode_tp_size - req_meta_update.remote_port = ( - req_meta_update.remote_port + (self.tp_rank // pd_tp_ratio) % self._decode_tp_size - ) - else: - req_meta_update.remote_port = ( - req_meta_update.remote_port + (self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size - ) if ( - req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr - or req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id] + req_meta.remote_engine_id not in self.remote_kv_caches_base_addr + or req_meta.remote_port not in self.remote_kv_caches_base_addr[req_meta.remote_engine_id] ): try: encoded_data = self.encoder.encode((GET_META_MSG, req_id)) - sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port) - path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}" + sock = self._get_remote_socket(req_meta.remote_host, req_meta.remote_port) + path = f"{req_meta.remote_host}:{req_meta.remote_port}" ensure_zmq_send(sock, encoded_data, path) metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path) agent_meta = self.decoder.decode(metadata_bytes) except Exception as e: logger.error( - f"Query to port and kv base addr for request {req_id} from " - f"{req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}" + f"Query to port and kv base addr for request {req_id}" + f"from {req_meta.remote_host}:{req_meta.remote_port}" + f"fail with error: {e}" ) - assert req_meta_update.remote_engine_id != self.engine_id, ( - f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id {self.local_engine_id}." + assert req_meta.remote_engine_id != self.engine_id, ( + f"Conflict engine id {req_meta.remote_engine_id} with local engine id {self.local_engine_id}." ) - self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][req_meta_update.remote_port] = ( + self.remote_kv_caches_base_addr[req_meta.remote_engine_id][req_meta.remote_port] = ( agent_meta.kv_caches_base_addr ) - self.remote_te_port[req_meta_update.remote_engine_id][req_meta_update.remote_port] = agent_meta.te_rpc_port + self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] = agent_meta.te_rpc_port logger.info( - f"Query to port and kv base addr for request {req_id} from " - f"{req_meta_update.remote_host}:{req_meta_update.remote_port} success " - f"{agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" + f"Query to port and kv base addr for request {req_id}" + f"from {req_meta.remote_host}:{req_meta.remote_port}" + f"success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" ) if self.pd_head_ratio > 1: # for tp inequal, pre-create link to prevent alltoall out of memory - session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" + session_id = f"{req_meta.remote_host}:{agent_meta.te_rpc_port}" ret = self.engine.batch_transfer_sync_write( session_id, [self.kv_caches_base_addr[0]], [agent_meta.kv_caches_base_addr[0]], [128] ) if ret < 0: logger.error(f"Mooncake transfer failed to create link to device {session_id}") - req_meta_update.remote_te_rpc_port = self.remote_te_port[req_meta_update.remote_engine_id][ - req_meta_update.remote_port + req_meta.remote_te_rpc_port = self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] + req_meta.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta.remote_engine_id][ + req_meta.remote_port ] - req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][ - req_meta_update.remote_port - ] - return req_meta_update + return req_meta def send_done_send_signal(self, req_id, req_meta): external_req_id = get_external_request_id(req_id) @@ -1221,7 +1313,7 @@ class MooncakeLayerwiseConnectorWorker: try: path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port) msg_encoder = msgspec.msgpack.Encoder() - encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id)) + encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count)) with zmq_ctx(zmq.REQ, path) as sock: # type: ignore ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") ack = sock.recv() diff --git a/vllm_ascend/distributed/kv_transfer/utils/utils.py b/vllm_ascend/distributed/kv_transfer/utils/utils.py index 4886f4ac..19ef0ed0 100644 --- a/vllm_ascend/distributed/kv_transfer/utils/utils.py +++ b/vllm_ascend/distributed/kv_transfer/utils/utils.py @@ -1,7 +1,12 @@ +import math import os +from collections import defaultdict +from dataclasses import dataclass +from typing import Any import torch import torch.distributed as dist +from vllm.logger import logger from vllm_ascend.distributed.parallel_state import get_p_tp_group @@ -50,3 +55,241 @@ def get_transfer_timeout_value(): hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20")) # type: ignore hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7")) # type: ignore return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000) + + +@dataclass +class parallel_info: + tp_size: int + pcp_size: int + dcp_size: int + use_mla: bool + pd_head_ratio: int + + +def get_cp_group(tp: int, heads: int, dcp: int): + # Partition the second dimension of [pcp][head_group][dcp] to obtain a complete head group + # head_group is all blocks for request in the same head + # tp8 dcp2 heads4 return[[0,1,2,3]] + # tp8 dcp1 heads4 return[[0,2,4,6],[1,3,5,7]] + step = tp // heads + if step == 0: + return [[i for i in range(tp // dcp)]] + else: + return [ + set([k // dcp for h in range(heads) for k in range(h * step + i * dcp, h * step + (i + 1) * dcp)]) + for i in range(step // dcp) + ] + + +def context_parallel_parameters_check( + remote_pcp_size: int, + remote_dcp_size: int, + p_parallel_info: parallel_info, + d_parallel_info: parallel_info, + total_num_kv_heads: int, +): + # Check whether the pcp–dcp ratio is supported + assert (p_parallel_info.pcp_size * p_parallel_info.dcp_size) % (remote_pcp_size * remote_dcp_size) == 0 + if not p_parallel_info.use_mla: + p_node_heads_per_rank = math.ceil(total_num_kv_heads / p_parallel_info.tp_size) + d_node_heads_per_rank = math.ceil(total_num_kv_heads / d_parallel_info.dcp_size) + assert d_node_heads_per_rank % p_node_heads_per_rank == 0 + + +def get_tp_rank_head_mapping(num_key_value_heads: int, tp_size: int): + # Get the head_idx corresponding to the tp_rank, {tp_rank:[head_indx]} + mapping = {} + if tp_size <= num_key_value_heads: + if num_key_value_heads % tp_size != 0: + raise ValueError(f"Number of heads ({num_key_value_heads}) cannot be evenly divided by TP ({tp_size}).") + + heads_per_rank = num_key_value_heads // tp_size + + for rank in range(tp_size): + start_idx = rank * heads_per_rank + end_idx = start_idx + heads_per_rank + mapping[rank] = list(range(start_idx, end_idx)) + else: + if tp_size % num_key_value_heads != 0: + raise ValueError(f"Number of heads ({num_key_value_heads}) cannot be evenly divided by TP ({tp_size}).") + ranks_per_head = tp_size // num_key_value_heads + for rank in range(tp_size): + head_idx = rank // ranks_per_head + mapping[rank] = [head_idx] + return mapping + + +def get_head_group_mapping(num_key_value_heads: int, tp_size: int, num_groups: int, select_cp_group: list[int]): + # Get the mapping dictionary, where the key is head_group_rank and the value is head_idx + if tp_size % num_groups != 0: + raise ValueError( + f"Total number of devices ({tp_size}) cannot be divided by the number of groups ({num_groups})." + ) + ranks_per_group = tp_size // num_groups + tp_mapping = get_tp_rank_head_mapping(num_key_value_heads, tp_size) + group_mapping = {} + for group_rank in range(num_groups): + if group_rank in select_cp_group: + start_rank = group_rank * ranks_per_group + end_rank = start_rank + ranks_per_group + heads_set = set() + + for rank in range(start_rank, end_rank): + heads_set.update(tp_mapping[rank]) + group_mapping[group_rank] = sorted(list(heads_set)) + return group_mapping + + +def get_local_remote_block_port_mappings( + to_trans_idx: int, + p_parallel_info: parallel_info, + d_parallel_info: parallel_info, + d_hosts: list[str], + d_port: int, + selected_p_cp_group: list[int], + selected_d_cp_group: list[int], + prompt_len: int, + block_size: int, + req_meta, + total_num_kv_heads: int, + req_id: str, +): + p_head_group_size = p_parallel_info.tp_size // p_parallel_info.dcp_size + d_head_group_size = d_parallel_info.tp_size // d_parallel_info.dcp_size + world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size + # Compute which logic_block_idx corresponds to each tp_rank + p_rank_block_mapping: list[list[list[list[int]]]] = [ + [[[] for _ in range(p_parallel_info.dcp_size)] for _ in range(p_head_group_size)] + for _ in range(p_parallel_info.pcp_size) + ] + for logic_block_idx in range(to_trans_idx): + pcp_rank = (logic_block_idx // p_parallel_info.dcp_size) % p_parallel_info.pcp_size + dcp_rank = logic_block_idx % p_parallel_info.dcp_size + for p_head_group_rank in range(p_head_group_size): + if p_head_group_rank in selected_p_cp_group: + p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank].append(logic_block_idx) + + # Find the remote device that holds the logic_block_idx + d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) + for logic_block_idx in range(to_trans_idx): + pcp_rank = (logic_block_idx // d_parallel_info.dcp_size) % d_parallel_info.pcp_size + for d_head_group_rank in range(d_head_group_size): + if d_head_group_rank in selected_d_cp_group: + dcp_rank = logic_block_idx % d_parallel_info.dcp_size + world_rank = ( + pcp_rank * d_head_group_size * d_parallel_info.dcp_size + + d_head_group_rank * d_parallel_info.dcp_size + + dcp_rank + ) + world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size + host = d_hosts[(len(d_hosts) * world_rank) // world_size] + port = d_port + world_rank + block_idx = (logic_block_idx - (pcp_rank * d_parallel_info.pcp_size + dcp_rank)) // ( + d_parallel_info.pcp_size * d_parallel_info.dcp_size + ) + d_block_rank_mapping[logic_block_idx][d_head_group_rank] = { + "pcp_rank": pcp_rank, + "dcp_rank": dcp_rank, + "host": host, + "port": port, + "block_idx": block_idx, + } + # Get how many times each device should receive done_single for this request + d_trans_count_mapping = {} + trans_block_size = math.ceil(prompt_len / block_size) # Total number of blocks + transed_block_size = math.ceil(req_meta.remote_cache_tokens / block_size) # Number of prefix cache hit blocks + d_cp_size = d_parallel_info.pcp_size * d_parallel_info.dcp_size + for d_pcp_rank in range(d_parallel_info.pcp_size): + for d_head_group_rank in range(d_head_group_size): + for d_dcp_rank in range(d_parallel_info.dcp_size): + if trans_block_size >= (p_parallel_info.pcp_size * p_parallel_info.dcp_size): + trans_count = (p_parallel_info.pcp_size * p_parallel_info.dcp_size) // d_cp_size + else: + current_rank_idx = d_pcp_rank * d_parallel_info.dcp_size + d_dcp_rank + total_global_blocks = transed_block_size + trans_block_size + + target_total_count = total_global_blocks // d_cp_size + if current_rank_idx < (total_global_blocks % d_cp_size): + target_total_count += 1 + + prev_processed_count = transed_block_size // d_cp_size + if current_rank_idx < (transed_block_size % d_cp_size): + prev_processed_count += 1 + + trans_count = target_total_count - prev_processed_count + world_rank = ( + d_pcp_rank * d_head_group_size * d_parallel_info.dcp_size + + d_head_group_rank * d_parallel_info.dcp_size + + d_dcp_rank + ) + host = d_hosts[(len(d_hosts) * world_rank) // world_size] + port = d_port + world_rank + d_trans_count_mapping[(host, port)] = trans_count * p_parallel_info.pd_head_ratio + + # Compute the mapping between local and remote head_group_rank + p_tp_rank_head_mapping = get_head_group_mapping( + total_num_kv_heads, p_parallel_info.tp_size, p_head_group_size, selected_p_cp_group + ) + d_tp_rank_head_mapping = get_head_group_mapping( + total_num_kv_heads, d_parallel_info.tp_size, d_head_group_size, selected_d_cp_group + ) + head_to_d_groups = defaultdict(set) + for d_rank, heads in d_tp_rank_head_mapping.items(): + for head in heads: + head_to_d_groups[head].add(d_rank) + pd_head_mapping = {} + for p_rank, p_heads in p_tp_rank_head_mapping.items(): + target_d_ranks = set() + for head in p_heads: + if head in head_to_d_groups: + target_d_ranks.update(head_to_d_groups[head]) + else: + logger.info(f"Warning: Head {head} exists in P but not in D mapping.") + pd_head_mapping[p_rank] = sorted(list(target_d_ranks)) + logger.debug( + f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} " + f"P-side logic_block to rank mapping: {p_rank_block_mapping}, " + f"D-side logic_block to rank mapping: {d_block_rank_mapping}, " + f"P&D head_group_rank mapping: {pd_head_mapping}" + ) + return p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping + + +def get_transfer_mappings( + p_rank_block_mapping: list[list[list[list[int]]]], + d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]], + pd_head_mapping: dict[int, set], + d_trans_count_mapping: dict[tuple[str, int], int], + req_meta, + p_parallel_info: parallel_info, + req_id: str, + transed_idx: int, + to_trans_idx: int, + tp_rank: int, + pcp_rank: int, + dcp_rank: int, +): + transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {} + p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size + p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank] + for p_block_idx, logic_block_idx in enumerate(p_block_idxs): + if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx: + continue + for d_head_group_rank in pd_head_mapping[p_head_group_rank]: + p_block_id = req_meta.local_block_ids[p_block_idx] + remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"] + remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"] + d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"] + d_block_id = req_meta.remote_block_ids[d_block_idx] + if (remote_host, remote_port) not in transfer_mappings: + transfer_mappings[(remote_host, remote_port)] = { + "local_block_ids": [], + "remote_block_ids": [], + "trans_count": 0, + } + transfer_mappings[(remote_host, remote_port)]["local_block_ids"].append(p_block_id) + transfer_mappings[(remote_host, remote_port)]["remote_block_ids"].append(d_block_id) + for (host, port), block_dict in transfer_mappings.items(): + block_dict["trans_count"] = d_trans_count_mapping[(host, port)] + logger.debug(f"MooncakeLayerwiseConnector Request {req_id} transfer tasks: {transfer_mappings}") + return transfer_mappings