diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index 487d4169..3cdbeb6b 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -232,15 +232,17 @@ class TestAscendAttentionCPImpl(TestBase): self.assertEqual(value.shape[1], num_heads) self.assertEqual(value.shape[2], head_size) + @patch('torch_npu.Event', create=True) @patch('torch_npu._npu_reshape_and_cache') @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) - def test_reshape_and_cache(self, mock_npu_reshape_and_cache): + def test_reshape_and_cache(self, mock_event_class, mock_npu_reshape_and_cache): num_tokens = 4 block_num = 100 block_size = 128 num_heads = 1 head_size = 128 self.impl.head_size = head_size + self.impl.is_kv_producer = False kv_cache = (torch.randn(block_num, block_size, num_heads, head_size), torch.randn(block_num, block_size, num_heads, head_size)) diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index f8266580..bdb1b02f 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -61,8 +61,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): 6000], # 2 * total_layers use_mla=True, block_len=[1024, 2048], - decode_tp_size=1, - first_kv_cache=self.first_kv_cache, k_buffer=self.fake_k_buffer, v_buffer=self.fake_v_buffer, resharding_stream=fake_resharding_stream, @@ -70,6 +68,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): self.req_meta_base = ReqMeta( local_block_ids=[5, 8], + remote_tp_size = 8, + remote_pcp_size = 1, + remote_dcp_size = 1, token_ids=[1, 2, 3], remote_block_ids=[10, 20], remote_engine_id="remote_engine", @@ -112,8 +113,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): kv_cache_base_addr=[1111, 2222, 3333, 4444], use_mla=False, block_len=[64], - decode_tp_size=1, - first_kv_cache=self.first_kv_cache, k_buffer=self.fake_k_buffer, v_buffer=self.fake_v_buffer, resharding_stream=fake_resharding_stream, @@ -242,12 +241,12 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): th.task_tracker["reqX"] = 0 th.request_map["reqX"] = "reqX" - th.update_task("reqX") + th.update_task("reqX", 2) with th.lock: self.assertIn("reqX", th.task_tracker) self.assertNotIn("reqX", th.done_requests) - th.update_task("reqX") + th.update_task("reqX", 2) with th.lock: self.assertNotIn("reqX", th.task_tracker) self.assertIn("reqX", th.done_requests) @@ -284,7 +283,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): dec_inst = MagicMock() dec_inst.decode.side_effect = [ (GET_META_MSG, ), - (DONE_SENDING_MSG, "reqA"), + (DONE_SENDING_MSG, "reqA", 1), (b"weird_msg", ), ] mock_Decoder.return_value = dec_inst @@ -339,21 +338,11 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): finished = th.get_and_clear_finished_requests() self.assertIn("reqA", finished) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", - return_value="127.0.0.1") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx" - ) + 
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", return_value="127.0.0.1") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx") def test_run_loop_pd_head_ratio_gt1_requires_multiple_done( self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip, _mock_logger): @@ -364,8 +353,8 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): dec_inst = MagicMock() dec_inst.decode.side_effect = [ - (DONE_SENDING_MSG, "reqB"), - (DONE_SENDING_MSG, "reqB"), + (DONE_SENDING_MSG, "reqB", 2), + (DONE_SENDING_MSG, "reqB", 2), ] mock_Decoder.return_value = dec_inst @@ -373,25 +362,26 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): sock.recv_multipart.side_effect = [ [b"ID", b"PAY1"], [b"ID", b"PAY2"], - SystemExit, + SystemExit, # 退出循环 ] cm = MagicMock() cm.__enter__.return_value = sock mock_zmq_ctx.return_value = cm - th = KVCacheRecvingLayerThread(tp_rank=0, - side_channel_port=5555, - tp_size=2, - pd_head_ratio=2, - local_engine_id="engineY", - metadata=self.meta, - ready_event=self.ready_event) + th = KVCacheRecvingLayerThread( + tp_rank=0, + side_channel_port=5555, + tp_size=2, + pd_head_ratio=2, + local_engine_id="engineY", + metadata=self.meta, + ready_event=self.ready_event + ) with th.lock: th.task_tracker["reqB"] = 0 th.request_map["reqB"] = "reqB" with self.assertRaises(SystemExit): th.run() - finished = th.get_and_clear_finished_requests() self.assertIn("reqB", finished) @@ -441,6 +431,7 @@ class MockRequest: self.kv_transfer_params = kv_transfer_params or {} self.status = status or "running" self.output_token_ids = [101, 102] + self.num_computed_tokens = 0 self.all_token_ids = list(self.prompt_token_ids) @@ -565,7 +556,8 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): req = MockRequest("req_u1", prompt_token_ids=list(range(24)), kv_transfer_params={"do_remote_prefill": True}) - blocks = _MockBlocks(unhashed=[4, 5, 6]) + req.num_computed_tokens = 0 + blocks = _MockBlocks(unhashed=[4, 5, 6], block_ids_tuple=([4, 5, 6], )) self.scheduler.update_state_after_alloc(req, blocks, @@ -592,7 +584,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): info = self.scheduler._reqs_need_send_layerwise["req_u2"] self.assertEqual(info.local_block_ids, [7, 8, 9]) self.assertIs(info.request, req) - self.assertEqual(info.remote_block_ids, []) def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True @@ -663,12 +654,13 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): send_req_info.update_computed_tokens = MagicMock() send_req_info.update_transferred_tokens = MagicMock() send_req_info.unpack = MagicMock( - return_value=(send_req_info.local_block_ids, - send_req_info.remote_block_ids, - send_req_info.remote_cache_tokens, - send_req_info.local_transferred_tokens, - send_req_info.local_computed_tokens, - send_req_info.request)) + return_value=( + send_req_info.local_block_ids, + send_req_info.local_transferred_tokens, + send_req_info.local_computed_tokens, + send_req_info.request + ) + ) self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info out = 
_MockSchedulerOutput( @@ -920,6 +912,11 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): self.vllm_config = MockVllmConfig() self.engine_id = "test_engine" self.kv_caches = {"layer1": (MagicMock(), MagicMock())} + self.vllm_config.parallel_config.tensor_parallel_size = 1 + self.vllm_config.parallel_config.prefill_context_parallel_size = 1 + self.vllm_config.parallel_config.decode_context_parallel_size = 1 + self.vllm_config.parallel_config.data_parallel_rank = 0 + self.vllm_config.kv_transfer_config.kv_port = 1234 def tearDown(self): for p in self.patches: @@ -956,4 +953,4 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): self.engine_id) worker.register_kv_caches(mla_caches) self.assertTrue(worker.use_mla) - self.assertEqual(len(worker.block_len), 2) + self.assertEqual(len(worker.block_len), 2) \ No newline at end of file diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py index 582830d8..6a560e80 100644 --- a/tests/ut/kv_connector/utils.py +++ b/tests/ut/kv_connector/utils.py @@ -8,14 +8,11 @@ from typing import Any, Optional import torch from vllm import SamplingParams -from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, - ModelConfig, SchedulerConfig, VllmConfig) +from vllm.config import CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.utils.hashing import sha256 -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager @@ -37,14 +34,10 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block) == 0 - num_free_blocks = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - assert num_free_blocks == ( - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block) == 0 + num_free_blocks = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. 
@@ -63,8 +56,7 @@ def create_vllm_config( max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_num_batched_tokens, ) - fake_weight_path = os.path.join(os.path.dirname(__file__), "..", - "fake_weight") + fake_weight_path = os.path.join(os.path.dirname(__file__), "..", "fake_weight") model_config = ModelConfig( model=fake_weight_path, skip_tokenizer_init=True, @@ -77,14 +69,14 @@ def create_vllm_config( cache_dtype="auto", enable_prefix_caching=True, ) - kv_transfer_config = KVTransferConfig( - kv_connector="MooncakeConnectorV1", - kv_role="kv_both") - return VllmConfig(scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - device_config=DeviceConfig("cpu")) + kv_transfer_config = KVTransferConfig(kv_connector="MooncakeConnectorV1", kv_role="kv_both") + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu"), + ) def create_scheduler( @@ -96,11 +88,7 @@ def create_scheduler( kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float16, - False, False)) - ], + kv_cache_groups=[KVCacheGroupSpec(["layer"], FullAttentionSpec(block_size, 1, 1, torch.float16, False, False))], ) vllm_config.cache_config.num_gpu_blocks = num_blocks @@ -138,19 +126,19 @@ def create_request( if do_remote_decode: assert not do_remote_prefill - kv_transfer_params = dict(do_remote_prefill=False, - do_remote_decode=True) + kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True) elif do_remote_prefill: - kv_transfer_params = dict(do_remote_prefill=True, - do_remote_decode=False, - remote_engine_id="my-engine-id", - remote_block_ids=list( - range(num_remote_blocks)), - remote_host="my-host", - remote_port=1234, - remote_tp_size=1, - remote_pcp_size=1, - remote_dcp_size=1) + kv_transfer_params = dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list(range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234, + remote_tp_size=1, + remote_pcp_size=1, + remote_dcp_size=1, + ) max_tokens = 1 if do_remote_decode else max_tokens sampling_params = SamplingParams(max_tokens=max_tokens) @@ -190,10 +178,9 @@ def create_model_runner_output( # Make output data structure. 
extra_args = {} - from vllm.v1.worker.kv_connector_model_runner_mixin import \ - KVConnectorOutput # type: ignore # noqa - kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, - finished_recving=finished_recving) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # type: ignore # noqa + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) extra_args = {"kv_connector_output": kv_connector_output} model_runner_output = ModelRunnerOutput( diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 8bffb12f..c2f919fe 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -743,6 +743,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): has_prefill = attn_metadata.num_prefills > 0 if len(kv_cache) > 1: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] @@ -778,7 +780,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): value_cache=self.value_cache, slot_indices=slot_mapping, ) - + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() return key, value def _gather_global_context_output(self, local_context_attn_output): diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index de1bc5f3..d30ce725 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -414,9 +414,13 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe prefill_k_c_normed = prefill_k_c_normed.squeeze() slot_mapping = attn_metadata.slot_mapping[self.pcp_size * num_decode_tokens :] + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() torch_npu._npu_reshape_and_cache( key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slot_mapping ) + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() prefill_k_nope, prefill_value = ( self.kv_b_proj(prefill_k_c_normed)[0] .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 03584cd3..7ea8698e 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -23,9 +23,16 @@ import torch_npu import zmq from mooncake.engine import TransferEngine # type: ignore from vllm.config import VllmConfig +from vllm.distributed import get_pcp_group from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group, get_world_group +from vllm.distributed.parallel_state import ( + get_decode_context_model_parallel_rank, + get_tensor_model_parallel_rank, + get_tp_group, + get_world_group, +) from vllm.logger import logger +from vllm.utils.math_utils import round_down from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig 
@@ -35,8 +42,13 @@ from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import GET_ME from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te from vllm_ascend.distributed.kv_transfer.utils.utils import ( align_memory, + context_parallel_parameters_check, + get_cp_group, + get_local_remote_block_port_mappings, + get_transfer_mappings, get_transfer_timeout_value, kv_alltoall_and_rearrange, + parallel_info, ) from vllm_ascend.utils import npu_stream_switch @@ -68,7 +80,15 @@ class ReqMeta: remote_te_rpc_port: int | None remote_kv_caches_base_addr: list[int] | None metaserver: str | None - chunk_finish: bool | None + remote_tp_size: int | None + remote_pcp_size: int | None + remote_dcp_size: int | None + chunk_finish: bool = False + prompt_len: int = 0 + trans_count: int = 0 + remote_cache_tokens: int = 0 + local_computed_tokens: int = 0 + local_transed_tokens: int = 0 @dataclass @@ -100,8 +120,6 @@ class TransferMeta: @dataclass class SendReqInfo: local_block_ids: list[int] - remote_block_ids: list[int] - remote_cache_tokens: int local_transferred_tokens: int local_computed_tokens: int request: "Request" @@ -121,8 +139,6 @@ class SendReqInfo: def unpack(self): return ( self.local_block_ids, - self.remote_block_ids, - self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request, @@ -161,8 +177,6 @@ class KVCacheSendingLayerThread(threading.Thread): kv_cache_base_addr: list[int], use_mla: bool, block_len: list[int], - decode_tp_size: int, - first_kv_cache: torch.Tensor, k_buffer: torch.Tensor, v_buffer: torch.Tensor, resharding_stream: torch.npu.Stream, @@ -178,7 +192,6 @@ class KVCacheSendingLayerThread(threading.Thread): self.use_mla = use_mla self.use_sparse = len(block_len) == 3 self.block_len = block_len - self._decode_tp_size = decode_tp_size self.resharding_stream = resharding_stream self.current_layer = -1 @@ -373,10 +386,10 @@ class KVCacheRecvingLayerThread(threading.Thread): self.done_requests = set() return finished_requests - def update_task(self, req_id): + def update_task(self, req_id, trans_count): with self.lock: self.task_tracker[req_id] += 1 - if self.task_tracker[req_id] == self.pd_head_ratio: + if self.task_tracker[req_id] == trans_count: self.task_tracker.pop(req_id) self.done_requests.add(self.request_map[req_id]) self.request_map.pop(req_id) @@ -411,7 +424,8 @@ class KVCacheRecvingLayerThread(threading.Thread): elif msg[0] == DONE_SENDING_MSG: logger.debug("Got DONE_RECVING_MSG for request %s", msg[1]) request_id = msg[1] - self.update_task(request_id) + trans_count = msg[2] + self.update_task(request_id, trans_count) sock.send_multipart((identity, b"", b"ACK")) else: logger.error("Connection listener got unexpected message %s", msg) @@ -431,6 +445,10 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): kv_transfer_params: dict[str, Any], token_ids: list[int] | None = None, chunk_finish: bool = False, + prompt_len: int = 0, + remote_cache_tokens: int = 0, + local_computed_tokens: int = 0, + local_transed_tokens: int = 0, ): self.requests[request_id] = ReqMeta( token_ids=token_ids or [], @@ -442,7 +460,14 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port"), remote_kv_caches_base_addr=kv_transfer_params.get("remote_kv_caches_base_addr"), metaserver=kv_transfer_params.get("metaserver"), + remote_tp_size=kv_transfer_params.get("remote_tp_size"), + 
remote_pcp_size=kv_transfer_params.get("remote_pcp_size"), + remote_dcp_size=kv_transfer_params.get("remote_dcp_size"), chunk_finish=chunk_finish, + remote_cache_tokens=remote_cache_tokens, + local_computed_tokens=local_computed_tokens, + prompt_len=prompt_len, + local_transed_tokens=local_transed_tokens, ) @@ -605,7 +630,8 @@ class MooncakeLayerwiseConnectorScheduler: ) if params is not None and params.get("do_remote_prefill"): - local_block_ids = blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] + local_block_ids = (blocks.get_block_ids()[0]) if num_external_tokens > 0 else [] + remote_cached_tokens = request.num_computed_tokens # Get unhashed blocks to pull from remote. logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue" @@ -632,6 +658,10 @@ class MooncakeLayerwiseConnectorScheduler: remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, + remote_tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + remote_pcp_size=self.vllm_config.parallel_config.prefill_context_parallel_size, + remote_dcp_size=self.vllm_config.parallel_config.decode_context_parallel_size, + remote_cached_tokens=remote_cached_tokens, ) future = self.executor.submit( @@ -658,8 +688,6 @@ class MooncakeLayerwiseConnectorScheduler: local_computed_tokens = 0 self._reqs_need_send_layerwise[request.request_id] = SendReqInfo( local_block_ids=local_block_ids, - remote_block_ids=remote_block_ids, - remote_cache_tokens=remote_cache_tokens, local_transferred_tokens=local_transferred_tokens, local_computed_tokens=local_computed_tokens, request=request, @@ -691,11 +719,9 @@ class MooncakeLayerwiseConnectorScheduler: cached_reqs = scheduler_output.scheduled_cached_reqs new_reqs = scheduler_output.scheduled_new_reqs scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens - # update local block ids for req_id, new_blocks in zip(cached_reqs.req_ids, cached_reqs.new_block_ids): if req_id in self._reqs_need_send_layerwise and new_blocks is not None: self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0]) - computed_tokens = dict( list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + [(x.req_id, x.num_computed_tokens) for x in new_reqs] @@ -703,6 +729,10 @@ class MooncakeLayerwiseConnectorScheduler: for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(): if req_id in self._reqs_need_send_layerwise: send_req_info = self._reqs_need_send_layerwise[req_id] + # update local transferred tokens + send_req_info.update_transferred_tokens( + round_down(send_req_info.local_computed_tokens, self.block_size) + ) # update local computed tokens, not transfer spec decode tokens spec_decode_tokens = ( len(scheduled_spec_decode_tokens[req_id]) if (req_id in scheduled_spec_decode_tokens) else 0 @@ -714,56 +744,36 @@ class MooncakeLayerwiseConnectorScheduler: def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False): ( local_block_ids, - remote_block_ids, - remote_cache_tokens, - local_transferred_tokens, + local_transed_tokens, local_computed_tokens, request, ) = send_req_info.unpack() - local_trans_block_ids = local_block_ids[ - (local_transferred_tokens // self.block_size) : (local_computed_tokens // self.block_size) - ] - remote_trans_block_ids = remote_block_ids[ - ((local_transferred_tokens - remote_cache_tokens) // self.block_size) : ( - (local_computed_tokens - remote_cache_tokens) // self.block_size - ) - ] - 
request.kv_transfer_params["remote_block_ids"] = remote_trans_block_ids - assert len(local_trans_block_ids) == len(remote_trans_block_ids), ( - f"len of local trans block ids : {len(local_trans_block_ids)} not equal to " - f"the len of remote trans block ids : {len(remote_trans_block_ids)}" - ) - adjusted_tokens = ( - local_computed_tokens - (self.block_size - 1) if chunk_finish else local_computed_tokens - ) - logger.info( - f"MooncakeLayerwiseConnector scheduler add transfer task: " - f"{req_id=} {local_block_ids=} {remote_block_ids=} " - f"{local_trans_block_ids=} {remote_trans_block_ids=} " - f"local_computed_tokens={adjusted_tokens} " - f"request.all_token_ids={len(request.all_token_ids)}" - ) meta.add_new_req( request_id=req_id, - local_block_ids=local_trans_block_ids, + local_block_ids=local_block_ids, kv_transfer_params=request.kv_transfer_params, token_ids=[], chunk_finish=chunk_finish, + remote_cache_tokens=request.kv_transfer_params.get("remote_cached_tokens"), + prompt_len=len(request.all_token_ids), + local_computed_tokens=local_computed_tokens, + local_transed_tokens=local_transed_tokens, + ) + logger.debug( + f"MooncakeLayerwiseConnector build_connector_meta: {req_id=}" + f"prompt_len={len(request.all_token_ids)} {local_computed_tokens=}" + f"{local_transed_tokens=}" + f"remote_cache_tokens={request.kv_transfer_params.get('remote_cached_tokens')}" + f"{chunk_finish=} {local_block_ids=}" + f"remote_block_ids={request.kv_transfer_params.get('remote_block_ids')}" ) - # update local_transferred_tokens - local_transferred_tokens = (local_computed_tokens // self.block_size) * self.block_size - send_req_info.update_transferred_tokens(local_transferred_tokens) - # no chunk or last chunk - if send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids): - send_req_info.update_computed_tokens(send_req_info.local_computed_tokens + self.block_size - 1) - add_tranfer_task(req_id, send_req_info, chunk_finish=True) + # whether chunk finish + chunk_finish = send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids) + + add_tranfer_task(req_id, send_req_info, chunk_finish=chunk_finish) + if chunk_finish: self._reqs_need_send_layerwise.pop(req_id) - # chunk - elif (send_req_info.local_computed_tokens // self.block_size) - ( - send_req_info.local_transferred_tokens // self.block_size - ) > 0: - add_tranfer_task(req_id, send_req_info) return meta def _access_metaserver(self, url, message): @@ -796,13 +806,7 @@ class MooncakeLayerwiseConnectorWorker: """Implementation of Worker side methods""" def __init__(self, vllm_config: VllmConfig, engine_id: str): - self._get_prefill_decode_size(vllm_config) os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(get_transfer_timeout_value()) - if self._prefill_tp_size < self._decode_tp_size: - raise ValueError( - f"prefill_tp_size: {self._prefill_tp_size} must be greater than" - f" or equal to the decode_tp_size: {self._decode_tp_size}" - ) if TransferEngine is None: raise RuntimeError("mooncake is not available") @@ -814,11 +818,20 @@ class MooncakeLayerwiseConnectorWorker: self.engine_id = engine_id self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 + self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 
self.tp_group = get_tp_group() + self._decode_tp_size: int | None = None self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) self.use_mla = self.vllm_config.model_config.use_mla + if self.use_mla: + self.total_num_kv_heads = 1 + else: + self.total_num_kv_heads = self.vllm_config.model_config.get_total_num_kv_heads() # Handshake base port self.side_channel_port = ( @@ -863,23 +876,6 @@ class MooncakeLayerwiseConnectorWorker: self.k_buffer: torch.Tensor | None = None self.v_buffer: torch.Tensor | None = None - def _get_prefill_decode_size(self, vllm_config: VllmConfig): - # get prefill tp and dp size from extra config - prefill_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {}) - - assert "tp_size" in prefill_parallel_config - self._prefill_tp_size = prefill_parallel_config["tp_size"] - - assert "dp_size" in prefill_parallel_config - self._prefill_dp_size = prefill_parallel_config["dp_size"] - - # get decode tp and dp size from extra config - decode_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("decode", {}) - assert "tp_size" in decode_parallel_config - self._decode_tp_size = decode_parallel_config["tp_size"] - assert "dp_size" in decode_parallel_config - self._decode_dp_size = decode_parallel_config["dp_size"] - def create_kv_buffer(self, first_kv_cache): if self.pd_head_ratio > 1: # regesit kv buffer for tp inequal @@ -977,8 +973,6 @@ class MooncakeLayerwiseConnectorWorker: kv_cache_base_addr=self.kv_caches_base_addr, use_mla=self.use_mla, block_len=self.block_len, - decode_tp_size=self._decode_tp_size, - first_kv_cache=first_kv_cache, k_buffer=self.k_buffer, v_buffer=self.v_buffer, resharding_stream=self.resharding_stream, @@ -1009,9 +1003,120 @@ class MooncakeLayerwiseConnectorWorker: else set() ) if len(done_recving) > 0: - logger.info("Number of completed KV cache recv requests: %d, receive requests: %d", 0, len(done_recving)) + logger.info( + f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}" + ) return set(), done_recving + # {(ip, port)]: {local_block_ids: [], remote_block_ids: {}}} + def _get_kv_split_metadata(self, req_meta, req_idx, req_id): + remote_pcp_size = req_meta.remote_pcp_size + remote_dcp_size = req_meta.remote_dcp_size + remote_tp_size = req_meta.remote_tp_size + remote_hosts = [req_meta.remote_host] + remote_port = req_meta.remote_port + local_transed_tokens = max(req_meta.remote_cache_tokens, req_meta.local_transed_tokens) + # local_transed_tokens tokens that have already been transmitted on the local side + local_computed_tokens = req_meta.local_computed_tokens + prompt_len = req_meta.prompt_len + p_parallel_info = parallel_info( + tp_size=self.tp_size, + pcp_size=self.pcp_size, + dcp_size=self.dcp_size, + pd_head_ratio=self.pd_head_ratio, + use_mla=self.use_mla, + ) + d_parallel_info = parallel_info( + tp_size=remote_tp_size, + pcp_size=remote_pcp_size, + dcp_size=remote_dcp_size, + pd_head_ratio=self.pd_head_ratio, + use_mla=self.use_mla, + ) + cp_size = self.pcp_size * self.dcp_size + # to_trans_idx all tokens that have been processed up to the current step + if req_meta.chunk_finish: + to_trans_idx = math.ceil(local_computed_tokens / self.block_size) + else: + to_trans_idx = math.floor(local_computed_tokens / self.block_size) + prompt_block_size = math.ceil(prompt_len / self.block_size) + # + 
num_local_blocks = prompt_block_size // cp_size + int( + (prompt_block_size % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank) + ) + already_send_blocks = to_trans_idx // cp_size + int( + (to_trans_idx % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank) + ) + if num_local_blocks == already_send_blocks: + req_meta.chunk_finish = True + transed_idx = math.floor(local_transed_tokens / self.block_size) + + p_cp_group = get_cp_group(self.tp_size, self.total_num_kv_heads, self.dcp_size) + d_cp_group = get_cp_group(remote_tp_size, self.total_num_kv_heads, remote_dcp_size) + logger.debug(f"Compute cp group for P&D {req_id=} {p_cp_group=} {d_cp_group=}") + + cp_ratio = len(p_cp_group) // len(d_cp_group) + if cp_ratio == 0: + selected_p_cp_groups = p_cp_group + selected_d_cp_groups = d_cp_group + else: + x = req_idx % cp_ratio + start = x * len(d_cp_group) + selected_p_cp_groups = p_cp_group[start : (start + len(d_cp_group))] + selected_d_cp_groups = d_cp_group + assert len(selected_p_cp_groups) == len(selected_d_cp_groups) + + p_head_group_rank = (self.tp_rank - self.dcp_rank) // self.dcp_size + selected_p_cp_group = [] + selected_d_cp_group = [] + for idx, cp_group in enumerate(selected_p_cp_groups): + if p_head_group_rank in cp_group: # Check whether the rank is in selected_p_cp_groups + selected_p_cp_group = cp_group + selected_d_cp_group = selected_d_cp_groups[idx] + if len(selected_p_cp_group) == 0: + return {} + + logger.debug( + f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} " + f"P-side selected head_group cp group: {selected_p_cp_group}, " + f"D-side selected head_group cp group: {selected_d_cp_group}" + ) + + context_parallel_parameters_check( + remote_pcp_size, remote_dcp_size, p_parallel_info, d_parallel_info, self.total_num_kv_heads + ) + p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping = ( + get_local_remote_block_port_mappings( + to_trans_idx, + p_parallel_info, + d_parallel_info, + remote_hosts, + remote_port, + selected_p_cp_group, + selected_d_cp_group, + prompt_len, + self.block_size, + req_meta, + self.total_num_kv_heads, + req_id, + ) + ) + transfer_mappings = get_transfer_mappings( + p_rank_block_mapping, + d_block_rank_mapping, + pd_head_mapping, + d_trans_count_mapping, + req_meta, + p_parallel_info, + req_id, + transed_idx, + to_trans_idx, + self.tp_rank, + self.pcp_rank, + self.dcp_rank, + ) + return transfer_mappings + def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): """Start loading KV blocks from remote engine.""" self.current_layer = 0 @@ -1023,31 +1128,29 @@ class MooncakeLayerwiseConnectorWorker: self.kv_recv_layer_thread.task_tracker[external_req_id] = 0 self.kv_recv_layer_thread.request_map[external_req_id] = req_id elif self.vllm_config.kv_transfer_config.is_kv_producer: - # select req to send - if self.use_mla or self.use_sparse: - num_need_send = self._decode_tp_size - else: - num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads - if self.tp_size <= num_kv_head: - num_need_send = self.tp_size - else: - num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head - num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1 - replica_group_idx = self.tp_rank % num_replica_groups - req_ids = sorted(list(metadata.requests.keys())) - selected_req_ids = [ - req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx - ] - request_ids = list(metadata.requests.keys()) 
- for req_id in request_ids: - if req_id not in selected_req_ids: - metadata.requests.pop(req_id) + # update trans info + update_metadata = {} + for req_idx, (req_id, req_meta) in enumerate(metadata.requests.items()): + self._decode_tp_size = req_meta.remote_tp_size + transfer_mappings = self._get_kv_split_metadata(req_meta, req_idx, req_id) + assert len(transfer_mappings) <= 1, f"Multiple transfer tasks are not supported for req_id: {req_id}" + update_req_meta = copy.deepcopy(req_meta) + for (host, port), block_dict in transfer_mappings.items(): + update_req_meta.remote_host = host + update_req_meta.remote_port = port + update_req_meta.local_block_ids = block_dict["local_block_ids"] + update_req_meta.remote_block_ids = block_dict["remote_block_ids"] + update_req_meta.trans_count = block_dict["trans_count"] + update_metadata[req_id] = update_req_meta + metadata.requests = {} + for req_id, req_meta in update_metadata.items(): + metadata.requests[req_id] = update_metadata[req_id] # update send task trans block info if self.pd_head_ratio != 1: send_task = metadata.send_task send_task.rearrange_block_ids = sorted( - {block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids} + {block_id for req_id in metadata.requests for block_id in metadata.requests[req_id].local_block_ids} ) device = self.k_buffer.device # type: ignore @@ -1070,7 +1173,7 @@ ) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(): - # enable decode prefix cache + # get reshape and cache event if self.use_mla or self.use_sparse: reshape_cache_event = attn_metadata[layer_name].reshape_cache_event else: @@ -1156,59 +1259,48 @@ return sock def update_decoder_info(self, req_id, req_meta): - req_meta_update = copy.deepcopy(req_meta) - if self.use_mla or self.use_sparse: - pd_tp_ratio = self.tp_size // self._decode_tp_size - req_meta_update.remote_port = ( - req_meta_update.remote_port + (self.tp_rank // pd_tp_ratio) % self._decode_tp_size - ) - else: - req_meta_update.remote_port = ( - req_meta_update.remote_port + (self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size - ) if ( - req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr - or req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id] + req_meta.remote_engine_id not in self.remote_kv_caches_base_addr + or req_meta.remote_port not in self.remote_kv_caches_base_addr[req_meta.remote_engine_id] ): try: encoded_data = self.encoder.encode((GET_META_MSG, req_id)) - sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port) - path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}" + sock = self._get_remote_socket(req_meta.remote_host, req_meta.remote_port) + path = f"{req_meta.remote_host}:{req_meta.remote_port}" ensure_zmq_send(sock, encoded_data, path) metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path) agent_meta = self.decoder.decode(metadata_bytes) except Exception as e: logger.error( - f"Query to port and kv base addr for request {req_id} from " - f"{req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}" + f"Query to port and kv base addr for request {req_id} " + f"from {req_meta.remote_host}:{req_meta.remote_port} " + f"fail with error: {e}" ) - assert req_meta_update.remote_engine_id != self.engine_id, ( - 
f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id {self.local_engine_id}." + assert req_meta.remote_engine_id != self.engine_id, ( + f"Conflict engine id {req_meta.remote_engine_id} with local engine id {self.local_engine_id}." ) - self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][req_meta_update.remote_port] = ( + self.remote_kv_caches_base_addr[req_meta.remote_engine_id][req_meta.remote_port] = ( agent_meta.kv_caches_base_addr ) - self.remote_te_port[req_meta_update.remote_engine_id][req_meta_update.remote_port] = agent_meta.te_rpc_port + self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] = agent_meta.te_rpc_port logger.info( - f"Query to port and kv base addr for request {req_id} from " - f"{req_meta_update.remote_host}:{req_meta_update.remote_port} success " - f"{agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" + f"Query to port and kv base addr for request {req_id}" + f"from {req_meta.remote_host}:{req_meta.remote_port}" + f"success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" ) if self.pd_head_ratio > 1: # for tp inequal, pre-create link to prevent alltoall out of memory - session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" + session_id = f"{req_meta.remote_host}:{agent_meta.te_rpc_port}" ret = self.engine.batch_transfer_sync_write( session_id, [self.kv_caches_base_addr[0]], [agent_meta.kv_caches_base_addr[0]], [128] ) if ret < 0: logger.error(f"Mooncake transfer failed to create link to device {session_id}") - req_meta_update.remote_te_rpc_port = self.remote_te_port[req_meta_update.remote_engine_id][ - req_meta_update.remote_port + req_meta.remote_te_rpc_port = self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] + req_meta.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta.remote_engine_id][ + req_meta.remote_port ] - req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][ - req_meta_update.remote_port - ] - return req_meta_update + return req_meta def send_done_send_signal(self, req_id, req_meta): external_req_id = get_external_request_id(req_id) @@ -1221,7 +1313,7 @@ class MooncakeLayerwiseConnectorWorker: try: path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port) msg_encoder = msgspec.msgpack.Encoder() - encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id)) + encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count)) with zmq_ctx(zmq.REQ, path) as sock: # type: ignore ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") ack = sock.recv() diff --git a/vllm_ascend/distributed/kv_transfer/utils/utils.py b/vllm_ascend/distributed/kv_transfer/utils/utils.py index 4886f4ac..19ef0ed0 100644 --- a/vllm_ascend/distributed/kv_transfer/utils/utils.py +++ b/vllm_ascend/distributed/kv_transfer/utils/utils.py @@ -1,7 +1,12 @@ +import math import os +from collections import defaultdict +from dataclasses import dataclass +from typing import Any import torch import torch.distributed as dist +from vllm.logger import logger from vllm_ascend.distributed.parallel_state import get_p_tp_group @@ -50,3 +55,241 @@ def get_transfer_timeout_value(): hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20")) # type: ignore hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7")) # type: ignore return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000) + + +@dataclass +class 
parallel_info: + tp_size: int + pcp_size: int + dcp_size: int + use_mla: bool + pd_head_ratio: int + + +def get_cp_group(tp: int, heads: int, dcp: int): + # Partition the second dimension of [pcp][head_group][dcp] to obtain a complete head group + # head_group is all blocks for request in the same head + # tp8 dcp2 heads4 return[[0,1,2,3]] + # tp8 dcp1 heads4 return[[0,2,4,6],[1,3,5,7]] + step = tp // heads + if step == 0: + return [[i for i in range(tp // dcp)]] + else: + return [ + set([k // dcp for h in range(heads) for k in range(h * step + i * dcp, h * step + (i + 1) * dcp)]) + for i in range(step // dcp) + ] + + +def context_parallel_parameters_check( + remote_pcp_size: int, + remote_dcp_size: int, + p_parallel_info: parallel_info, + d_parallel_info: parallel_info, + total_num_kv_heads: int, +): + # Check whether the pcp–dcp ratio is supported + assert (p_parallel_info.pcp_size * p_parallel_info.dcp_size) % (remote_pcp_size * remote_dcp_size) == 0 + if not p_parallel_info.use_mla: + p_node_heads_per_rank = math.ceil(total_num_kv_heads / p_parallel_info.tp_size) + d_node_heads_per_rank = math.ceil(total_num_kv_heads / d_parallel_info.dcp_size) + assert d_node_heads_per_rank % p_node_heads_per_rank == 0 + + +def get_tp_rank_head_mapping(num_key_value_heads: int, tp_size: int): + # Get the head_idx corresponding to the tp_rank, {tp_rank: [head_idx]} + mapping = {} + if tp_size <= num_key_value_heads: + if num_key_value_heads % tp_size != 0: + raise ValueError(f"Number of heads ({num_key_value_heads}) cannot be evenly divided by TP ({tp_size}).") + + heads_per_rank = num_key_value_heads // tp_size + + for rank in range(tp_size): + start_idx = rank * heads_per_rank + end_idx = start_idx + heads_per_rank + mapping[rank] = list(range(start_idx, end_idx)) + else: + if tp_size % num_key_value_heads != 0: + raise ValueError(f"TP size ({tp_size}) cannot be evenly divided by the number of heads ({num_key_value_heads}).") + ranks_per_head = tp_size // num_key_value_heads + for rank in range(tp_size): + head_idx = rank // ranks_per_head + mapping[rank] = [head_idx] + return mapping + + +def get_head_group_mapping(num_key_value_heads: int, tp_size: int, num_groups: int, select_cp_group: list[int]): + # Get the mapping dictionary, where the key is head_group_rank and the value is head_idx + if tp_size % num_groups != 0: + raise ValueError( + f"Total number of devices ({tp_size}) cannot be divided by the number of groups ({num_groups})."
+ ) + ranks_per_group = tp_size // num_groups + tp_mapping = get_tp_rank_head_mapping(num_key_value_heads, tp_size) + group_mapping = {} + for group_rank in range(num_groups): + if group_rank in select_cp_group: + start_rank = group_rank * ranks_per_group + end_rank = start_rank + ranks_per_group + heads_set = set() + + for rank in range(start_rank, end_rank): + heads_set.update(tp_mapping[rank]) + group_mapping[group_rank] = sorted(list(heads_set)) + return group_mapping + + +def get_local_remote_block_port_mappings( + to_trans_idx: int, + p_parallel_info: parallel_info, + d_parallel_info: parallel_info, + d_hosts: list[str], + d_port: int, + selected_p_cp_group: list[int], + selected_d_cp_group: list[int], + prompt_len: int, + block_size: int, + req_meta, + total_num_kv_heads: int, + req_id: str, +): + p_head_group_size = p_parallel_info.tp_size // p_parallel_info.dcp_size + d_head_group_size = d_parallel_info.tp_size // d_parallel_info.dcp_size + world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size + # Compute which logic_block_idx corresponds to each tp_rank + p_rank_block_mapping: list[list[list[list[int]]]] = [ + [[[] for _ in range(p_parallel_info.dcp_size)] for _ in range(p_head_group_size)] + for _ in range(p_parallel_info.pcp_size) + ] + for logic_block_idx in range(to_trans_idx): + pcp_rank = (logic_block_idx // p_parallel_info.dcp_size) % p_parallel_info.pcp_size + dcp_rank = logic_block_idx % p_parallel_info.dcp_size + for p_head_group_rank in range(p_head_group_size): + if p_head_group_rank in selected_p_cp_group: + p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank].append(logic_block_idx) + + # Find the remote device that holds the logic_block_idx + d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) + for logic_block_idx in range(to_trans_idx): + pcp_rank = (logic_block_idx // d_parallel_info.dcp_size) % d_parallel_info.pcp_size + for d_head_group_rank in range(d_head_group_size): + if d_head_group_rank in selected_d_cp_group: + dcp_rank = logic_block_idx % d_parallel_info.dcp_size + world_rank = ( + pcp_rank * d_head_group_size * d_parallel_info.dcp_size + + d_head_group_rank * d_parallel_info.dcp_size + + dcp_rank + ) + world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size + host = d_hosts[(len(d_hosts) * world_rank) // world_size] + port = d_port + world_rank + block_idx = (logic_block_idx - (pcp_rank * d_parallel_info.pcp_size + dcp_rank)) // ( + d_parallel_info.pcp_size * d_parallel_info.dcp_size + ) + d_block_rank_mapping[logic_block_idx][d_head_group_rank] = { + "pcp_rank": pcp_rank, + "dcp_rank": dcp_rank, + "host": host, + "port": port, + "block_idx": block_idx, + } + # Get how many times each device should receive done_single for this request + d_trans_count_mapping = {} + trans_block_size = math.ceil(prompt_len / block_size) # Total number of blocks + transed_block_size = math.ceil(req_meta.remote_cache_tokens / block_size) # Number of prefix cache hit blocks + d_cp_size = d_parallel_info.pcp_size * d_parallel_info.dcp_size + for d_pcp_rank in range(d_parallel_info.pcp_size): + for d_head_group_rank in range(d_head_group_size): + for d_dcp_rank in range(d_parallel_info.dcp_size): + if trans_block_size >= (p_parallel_info.pcp_size * p_parallel_info.dcp_size): + trans_count = (p_parallel_info.pcp_size * p_parallel_info.dcp_size) // d_cp_size + else: + current_rank_idx = d_pcp_rank * d_parallel_info.dcp_size + d_dcp_rank + total_global_blocks = 
transed_block_size + trans_block_size + + target_total_count = total_global_blocks // d_cp_size + if current_rank_idx < (total_global_blocks % d_cp_size): + target_total_count += 1 + + prev_processed_count = transed_block_size // d_cp_size + if current_rank_idx < (transed_block_size % d_cp_size): + prev_processed_count += 1 + + trans_count = target_total_count - prev_processed_count + world_rank = ( + d_pcp_rank * d_head_group_size * d_parallel_info.dcp_size + + d_head_group_rank * d_parallel_info.dcp_size + + d_dcp_rank + ) + host = d_hosts[(len(d_hosts) * world_rank) // world_size] + port = d_port + world_rank + d_trans_count_mapping[(host, port)] = trans_count * p_parallel_info.pd_head_ratio + + # Compute the mapping between local and remote head_group_rank + p_tp_rank_head_mapping = get_head_group_mapping( + total_num_kv_heads, p_parallel_info.tp_size, p_head_group_size, selected_p_cp_group + ) + d_tp_rank_head_mapping = get_head_group_mapping( + total_num_kv_heads, d_parallel_info.tp_size, d_head_group_size, selected_d_cp_group + ) + head_to_d_groups = defaultdict(set) + for d_rank, heads in d_tp_rank_head_mapping.items(): + for head in heads: + head_to_d_groups[head].add(d_rank) + pd_head_mapping = {} + for p_rank, p_heads in p_tp_rank_head_mapping.items(): + target_d_ranks = set() + for head in p_heads: + if head in head_to_d_groups: + target_d_ranks.update(head_to_d_groups[head]) + else: + logger.info(f"Warning: Head {head} exists in P but not in D mapping.") + pd_head_mapping[p_rank] = sorted(list(target_d_ranks)) + logger.debug( + f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} " + f"P-side logic_block to rank mapping: {p_rank_block_mapping}, " + f"D-side logic_block to rank mapping: {d_block_rank_mapping}, " + f"P&D head_group_rank mapping: {pd_head_mapping}" + ) + return p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping + + +def get_transfer_mappings( + p_rank_block_mapping: list[list[list[list[int]]]], + d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]], + pd_head_mapping: dict[int, set], + d_trans_count_mapping: dict[tuple[str, int], int], + req_meta, + p_parallel_info: parallel_info, + req_id: str, + transed_idx: int, + to_trans_idx: int, + tp_rank: int, + pcp_rank: int, + dcp_rank: int, +): + transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {} + p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size + p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank] + for p_block_idx, logic_block_idx in enumerate(p_block_idxs): + if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx: + continue + for d_head_group_rank in pd_head_mapping[p_head_group_rank]: + p_block_id = req_meta.local_block_ids[p_block_idx] + remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"] + remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"] + d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"] + d_block_id = req_meta.remote_block_ids[d_block_idx] + if (remote_host, remote_port) not in transfer_mappings: + transfer_mappings[(remote_host, remote_port)] = { + "local_block_ids": [], + "remote_block_ids": [], + "trans_count": 0, + } + transfer_mappings[(remote_host, remote_port)]["local_block_ids"].append(p_block_id) + transfer_mappings[(remote_host, remote_port)]["remote_block_ids"].append(d_block_id) + for (host, port), block_dict in transfer_mappings.items(): + block_dict["trans_count"] = 
d_trans_count_mapping[(host, port)] + logger.debug(f"MooncakeLayerwiseConnector Request {req_id} transfer tasks: {transfer_mappings}") + return transfer_mappings
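
Reviewer note: a minimal standalone sketch (not part of the patch) showing what the new get_cp_group helper in vllm_ascend/distributed/kv_transfer/utils/utils.py returns for the two configurations named in its inline comments. The helper body follows the implementation added above; the printed values are the expected head-group rank groupings.

# Illustrative only: mirrors the get_cp_group helper added by this patch.
def get_cp_group(tp: int, heads: int, dcp: int):
    # One head group spans tp // heads consecutive TP ranks; ranks are then
    # collapsed by dcp so each entry identifies a head-group rank.
    step = tp // heads
    if step == 0:
        return [[i for i in range(tp // dcp)]]
    else:
        return [
            set([k // dcp for h in range(heads) for k in range(h * step + i * dcp, h * step + (i + 1) * dcp)])
            for i in range(step // dcp)
        ]

print(get_cp_group(tp=8, heads=4, dcp=2))  # [{0, 1, 2, 3}] -> one group covering all head-group ranks
print(get_cp_group(tp=8, heads=4, dcp=1))  # [{0, 2, 4, 6}, {1, 3, 5, 7}]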