[P/D][PCP] mooncake layerwise support pcp function (#6627)

### What this PR does / why we need it?
Add PCP (Prefill Context Parallelism) support to the Mooncake layerwise connector.
PCP (Prefill Context Parallelism) Support: Introduced explicit support
for Prefill Context Parallelism (PCP) and Decode Context Parallelism
(DCP) in the Mooncake layerwise KV cache transfer mechanism, allowing
for more granular control and awareness of parallel configurations
during data transfer.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By CI.

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
wangxiaoteng888
2026-02-12 11:02:25 +08:00
committed by GitHub
parent 8b23554741
commit b881fab416
7 changed files with 551 additions and 223 deletions

View File

@@ -232,15 +232,17 @@ class TestAscendAttentionCPImpl(TestBase):
self.assertEqual(value.shape[1], num_heads)
self.assertEqual(value.shape[2], head_size)
@patch('torch_npu.Event', create=True)
@patch('torch_npu._npu_reshape_and_cache')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_reshape_and_cache(self, mock_npu_reshape_and_cache):
def test_reshape_and_cache(self, mock_event_class, mock_npu_reshape_and_cache):
num_tokens = 4
block_num = 100
block_size = 128
num_heads = 1
head_size = 128
self.impl.head_size = head_size
self.impl.is_kv_producer = False
kv_cache = (torch.randn(block_num, block_size, num_heads, head_size),
torch.randn(block_num, block_size, num_heads, head_size))

View File

@@ -61,8 +61,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
6000], # 2 * total_layers
use_mla=True,
block_len=[1024, 2048],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
@@ -70,6 +68,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.req_meta_base = ReqMeta(
local_block_ids=[5, 8],
remote_tp_size = 8,
remote_pcp_size = 1,
remote_dcp_size = 1,
token_ids=[1, 2, 3],
remote_block_ids=[10, 20],
remote_engine_id="remote_engine",
@@ -112,8 +113,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
kv_cache_base_addr=[1111, 2222, 3333, 4444],
use_mla=False,
block_len=[64],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
@@ -242,12 +241,12 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
th.task_tracker["reqX"] = 0
th.request_map["reqX"] = "reqX"
th.update_task("reqX")
th.update_task("reqX", 2)
with th.lock:
self.assertIn("reqX", th.task_tracker)
self.assertNotIn("reqX", th.done_requests)
th.update_task("reqX")
th.update_task("reqX", 2)
with th.lock:
self.assertNotIn("reqX", th.task_tracker)
self.assertIn("reqX", th.done_requests)
@@ -284,7 +283,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
dec_inst = MagicMock()
dec_inst.decode.side_effect = [
(GET_META_MSG, ),
(DONE_SENDING_MSG, "reqA"),
(DONE_SENDING_MSG, "reqA", 1),
(b"weird_msg", ),
]
mock_Decoder.return_value = dec_inst
@@ -339,21 +338,11 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
finished = th.get_and_clear_finished_requests()
self.assertIn("reqA", finished)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip",
return_value="127.0.0.1")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx"
)
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", return_value="127.0.0.1")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx")
def test_run_loop_pd_head_ratio_gt1_requires_multiple_done(
self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip,
_mock_logger):
@@ -364,8 +353,8 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
dec_inst = MagicMock()
dec_inst.decode.side_effect = [
(DONE_SENDING_MSG, "reqB"),
(DONE_SENDING_MSG, "reqB"),
(DONE_SENDING_MSG, "reqB", 2),
(DONE_SENDING_MSG, "reqB", 2),
]
mock_Decoder.return_value = dec_inst
@@ -373,25 +362,26 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
sock.recv_multipart.side_effect = [
[b"ID", b"PAY1"],
[b"ID", b"PAY2"],
SystemExit,
SystemExit,  # exit the receive loop
]
cm = MagicMock()
cm.__enter__.return_value = sock
mock_zmq_ctx.return_value = cm
th = KVCacheRecvingLayerThread(tp_rank=0,
side_channel_port=5555,
tp_size=2,
pd_head_ratio=2,
local_engine_id="engineY",
metadata=self.meta,
ready_event=self.ready_event)
th = KVCacheRecvingLayerThread(
tp_rank=0,
side_channel_port=5555,
tp_size=2,
pd_head_ratio=2,
local_engine_id="engineY",
metadata=self.meta,
ready_event=self.ready_event
)
with th.lock:
th.task_tracker["reqB"] = 0
th.request_map["reqB"] = "reqB"
with self.assertRaises(SystemExit):
th.run()
finished = th.get_and_clear_finished_requests()
self.assertIn("reqB", finished)
@@ -441,6 +431,7 @@ class MockRequest:
self.kv_transfer_params = kv_transfer_params or {}
self.status = status or "running"
self.output_token_ids = [101, 102]
self.num_computed_tokens = 0
self.all_token_ids = list(self.prompt_token_ids)
@@ -565,7 +556,8 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
req = MockRequest("req_u1",
prompt_token_ids=list(range(24)),
kv_transfer_params={"do_remote_prefill": True})
blocks = _MockBlocks(unhashed=[4, 5, 6])
req.num_computed_tokens = 0
blocks = _MockBlocks(unhashed=[4, 5, 6], block_ids_tuple=([4, 5, 6], ))
self.scheduler.update_state_after_alloc(req,
blocks,
@@ -592,7 +584,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
info = self.scheduler._reqs_need_send_layerwise["req_u2"]
self.assertEqual(info.local_block_ids, [7, 8, 9])
self.assertIs(info.request, req)
self.assertEqual(info.remote_block_ids, [])
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
@@ -663,12 +654,13 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
send_req_info.update_computed_tokens = MagicMock()
send_req_info.update_transferred_tokens = MagicMock()
send_req_info.unpack = MagicMock(
return_value=(send_req_info.local_block_ids,
send_req_info.remote_block_ids,
send_req_info.remote_cache_tokens,
send_req_info.local_transferred_tokens,
send_req_info.local_computed_tokens,
send_req_info.request))
return_value=(
send_req_info.local_block_ids,
send_req_info.local_transferred_tokens,
send_req_info.local_computed_tokens,
send_req_info.request
)
)
self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
out = _MockSchedulerOutput(
@@ -920,6 +912,11 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
self.vllm_config = MockVllmConfig()
self.engine_id = "test_engine"
self.kv_caches = {"layer1": (MagicMock(), MagicMock())}
self.vllm_config.parallel_config.tensor_parallel_size = 1
self.vllm_config.parallel_config.prefill_context_parallel_size = 1
self.vllm_config.parallel_config.decode_context_parallel_size = 1
self.vllm_config.parallel_config.data_parallel_rank = 0
self.vllm_config.kv_transfer_config.kv_port = 1234
def tearDown(self):
for p in self.patches:
@@ -956,4 +953,4 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
self.engine_id)
worker.register_kv_caches(mla_caches)
self.assertTrue(worker.use_mla)
self.assertEqual(len(worker.block_len), 2)
self.assertEqual(len(worker.block_len), 2)

View File

@@ -8,14 +8,11 @@ from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.config import CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
@@ -37,14 +34,10 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block) == 0
num_free_blocks = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
@@ -63,8 +56,7 @@ def create_vllm_config(
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
)
fake_weight_path = os.path.join(os.path.dirname(__file__), "..",
"fake_weight")
fake_weight_path = os.path.join(os.path.dirname(__file__), "..", "fake_weight")
model_config = ModelConfig(
model=fake_weight_path,
skip_tokenizer_init=True,
@@ -77,14 +69,14 @@ def create_vllm_config(
cache_dtype="auto",
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="MooncakeConnectorV1",
kv_role="kv_both")
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"))
kv_transfer_config = KVTransferConfig(kv_connector="MooncakeConnectorV1", kv_role="kv_both")
return VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"),
)
def create_scheduler(
@@ -96,11 +88,7 @@ def create_scheduler(
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float16,
False, False))
],
kv_cache_groups=[KVCacheGroupSpec(["layer"], FullAttentionSpec(block_size, 1, 1, torch.float16, False, False))],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
@@ -138,19 +126,19 @@ def create_request(
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True)
kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = dict(do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1,
remote_pcp_size=1,
remote_dcp_size=1)
kv_transfer_params = dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1,
remote_pcp_size=1,
remote_dcp_size=1,
)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
@@ -190,10 +178,9 @@ def create_model_runner_output(
# Make output data structure.
extra_args = {}
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
finished_recving=finished_recving)
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # type: ignore # noqa
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving)
extra_args = {"kv_connector_output": kv_connector_output}
model_runner_output = ModelRunnerOutput(

View File

@@ -743,6 +743,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
has_prefill = attn_metadata.num_prefills > 0
if len(kv_cache) > 1:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
@@ -778,7 +780,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
value_cache=self.value_cache,
slot_indices=slot_mapping,
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value
def _gather_global_context_output(self, local_context_attn_output):

View File

@@ -414,9 +414,13 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
prefill_k_c_normed = prefill_k_c_normed.squeeze()
slot_mapping = attn_metadata.slot_mapping[self.pcp_size * num_decode_tokens :]
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
torch_npu._npu_reshape_and_cache(
key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slot_mapping
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
prefill_k_nope, prefill_value = (
self.kv_b_proj(prefill_k_c_normed)[0]
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)

View File

@@ -23,9 +23,16 @@ import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm.config import VllmConfig
from vllm.distributed import get_pcp_group
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group, get_world_group
from vllm.distributed.parallel_state import (
get_decode_context_model_parallel_rank,
get_tensor_model_parallel_rank,
get_tp_group,
get_world_group,
)
from vllm.logger import logger
from vllm.utils.math_utils import round_down
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -35,8 +42,13 @@ from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import GET_ME
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.kv_transfer.utils.utils import (
align_memory,
context_parallel_parameters_check,
get_cp_group,
get_local_remote_block_port_mappings,
get_transfer_mappings,
get_transfer_timeout_value,
kv_alltoall_and_rearrange,
parallel_info,
)
from vllm_ascend.utils import npu_stream_switch
@@ -68,7 +80,15 @@ class ReqMeta:
remote_te_rpc_port: int | None
remote_kv_caches_base_addr: list[int] | None
metaserver: str | None
chunk_finish: bool | None
remote_tp_size: int | None
remote_pcp_size: int | None
remote_dcp_size: int | None
chunk_finish: bool = False
prompt_len: int = 0
trans_count: int = 0
remote_cache_tokens: int = 0
local_computed_tokens: int = 0
local_transed_tokens: int = 0
@dataclass
@@ -100,8 +120,6 @@ class TransferMeta:
@dataclass
class SendReqInfo:
local_block_ids: list[int]
remote_block_ids: list[int]
remote_cache_tokens: int
local_transferred_tokens: int
local_computed_tokens: int
request: "Request"
@@ -121,8 +139,6 @@ class SendReqInfo:
def unpack(self):
return (
self.local_block_ids,
self.remote_block_ids,
self.remote_cache_tokens,
self.local_transferred_tokens,
self.local_computed_tokens,
self.request,
@@ -161,8 +177,6 @@ class KVCacheSendingLayerThread(threading.Thread):
kv_cache_base_addr: list[int],
use_mla: bool,
block_len: list[int],
decode_tp_size: int,
first_kv_cache: torch.Tensor,
k_buffer: torch.Tensor,
v_buffer: torch.Tensor,
resharding_stream: torch.npu.Stream,
@@ -178,7 +192,6 @@ class KVCacheSendingLayerThread(threading.Thread):
self.use_mla = use_mla
self.use_sparse = len(block_len) == 3
self.block_len = block_len
self._decode_tp_size = decode_tp_size
self.resharding_stream = resharding_stream
self.current_layer = -1
@@ -373,10 +386,10 @@ class KVCacheRecvingLayerThread(threading.Thread):
self.done_requests = set()
return finished_requests
def update_task(self, req_id):
def update_task(self, req_id, trans_count):
with self.lock:
self.task_tracker[req_id] += 1
if self.task_tracker[req_id] == self.pd_head_ratio:
if self.task_tracker[req_id] == trans_count:
self.task_tracker.pop(req_id)
self.done_requests.add(self.request_map[req_id])
self.request_map.pop(req_id)
@@ -411,7 +424,8 @@ class KVCacheRecvingLayerThread(threading.Thread):
elif msg[0] == DONE_SENDING_MSG:
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
request_id = msg[1]
self.update_task(request_id)
trans_count = msg[2]
self.update_task(request_id, trans_count)
sock.send_multipart((identity, b"", b"ACK"))
else:
logger.error("Connection listener got unexpected message %s", msg)
@@ -431,6 +445,10 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
kv_transfer_params: dict[str, Any],
token_ids: list[int] | None = None,
chunk_finish: bool = False,
prompt_len: int = 0,
remote_cache_tokens: int = 0,
local_computed_tokens: int = 0,
local_transed_tokens: int = 0,
):
self.requests[request_id] = ReqMeta(
token_ids=token_ids or [],
@@ -442,7 +460,14 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port"),
remote_kv_caches_base_addr=kv_transfer_params.get("remote_kv_caches_base_addr"),
metaserver=kv_transfer_params.get("metaserver"),
remote_tp_size=kv_transfer_params.get("remote_tp_size"),
remote_pcp_size=kv_transfer_params.get("remote_pcp_size"),
remote_dcp_size=kv_transfer_params.get("remote_dcp_size"),
chunk_finish=chunk_finish,
remote_cache_tokens=remote_cache_tokens,
local_computed_tokens=local_computed_tokens,
prompt_len=prompt_len,
local_transed_tokens=local_transed_tokens,
)
@@ -605,7 +630,8 @@ class MooncakeLayerwiseConnectorScheduler:
)
if params is not None and params.get("do_remote_prefill"):
local_block_ids = blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
local_block_ids = (blocks.get_block_ids()[0]) if num_external_tokens > 0 else []
remote_cached_tokens = request.num_computed_tokens
# Get unhashed blocks to pull from remote.
logger.debug(
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue"
@@ -632,6 +658,10 @@ class MooncakeLayerwiseConnectorScheduler:
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
remote_tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
remote_pcp_size=self.vllm_config.parallel_config.prefill_context_parallel_size,
remote_dcp_size=self.vllm_config.parallel_config.decode_context_parallel_size,
remote_cached_tokens=remote_cached_tokens,
)
future = self.executor.submit(
@@ -658,8 +688,6 @@ class MooncakeLayerwiseConnectorScheduler:
local_computed_tokens = 0
self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
remote_cache_tokens=remote_cache_tokens,
local_transferred_tokens=local_transferred_tokens,
local_computed_tokens=local_computed_tokens,
request=request,
@@ -691,11 +719,9 @@ class MooncakeLayerwiseConnectorScheduler:
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
# update local block ids
for req_id, new_blocks in zip(cached_reqs.req_ids, cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0])
computed_tokens = dict(
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens))
+ [(x.req_id, x.num_computed_tokens) for x in new_reqs]
@@ -703,6 +729,10 @@ class MooncakeLayerwiseConnectorScheduler:
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items():
if req_id in self._reqs_need_send_layerwise:
send_req_info = self._reqs_need_send_layerwise[req_id]
# update local transferred tokens
send_req_info.update_transferred_tokens(
round_down(send_req_info.local_computed_tokens, self.block_size)
)
# update local computed tokens, not transfer spec decode tokens
spec_decode_tokens = (
len(scheduled_spec_decode_tokens[req_id]) if (req_id in scheduled_spec_decode_tokens) else 0
@@ -714,56 +744,36 @@ class MooncakeLayerwiseConnectorScheduler:
def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False):
(
local_block_ids,
remote_block_ids,
remote_cache_tokens,
local_transferred_tokens,
local_transed_tokens,
local_computed_tokens,
request,
) = send_req_info.unpack()
local_trans_block_ids = local_block_ids[
(local_transferred_tokens // self.block_size) : (local_computed_tokens // self.block_size)
]
remote_trans_block_ids = remote_block_ids[
((local_transferred_tokens - remote_cache_tokens) // self.block_size) : (
(local_computed_tokens - remote_cache_tokens) // self.block_size
)
]
request.kv_transfer_params["remote_block_ids"] = remote_trans_block_ids
assert len(local_trans_block_ids) == len(remote_trans_block_ids), (
f"len of local trans block ids : {len(local_trans_block_ids)} not equal to "
f"the len of remote trans block ids : {len(remote_trans_block_ids)}"
)
adjusted_tokens = (
local_computed_tokens - (self.block_size - 1) if chunk_finish else local_computed_tokens
)
logger.info(
f"MooncakeLayerwiseConnector scheduler add transfer task: "
f"{req_id=} {local_block_ids=} {remote_block_ids=} "
f"{local_trans_block_ids=} {remote_trans_block_ids=} "
f"local_computed_tokens={adjusted_tokens} "
f"request.all_token_ids={len(request.all_token_ids)}"
)
meta.add_new_req(
request_id=req_id,
local_block_ids=local_trans_block_ids,
local_block_ids=local_block_ids,
kv_transfer_params=request.kv_transfer_params,
token_ids=[],
chunk_finish=chunk_finish,
remote_cache_tokens=request.kv_transfer_params.get("remote_cached_tokens"),
prompt_len=len(request.all_token_ids),
local_computed_tokens=local_computed_tokens,
local_transed_tokens=local_transed_tokens,
)
logger.debug(
f"MooncakeLayerwiseConnector build_connector_meta: {req_id=}"
f"prompt_len={len(request.all_token_ids)} {local_computed_tokens=}"
f"{local_transed_tokens=}"
f"remote_cache_tokens={request.kv_transfer_params.get('remote_cached_tokens')}"
f"{chunk_finish=} {local_block_ids=}"
f"remote_block_ids={request.kv_transfer_params.get('remote_block_ids')}"
)
# update local_transferred_tokens
local_transferred_tokens = (local_computed_tokens // self.block_size) * self.block_size
send_req_info.update_transferred_tokens(local_transferred_tokens)
# no chunk or last chunk
if send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids):
send_req_info.update_computed_tokens(send_req_info.local_computed_tokens + self.block_size - 1)
add_tranfer_task(req_id, send_req_info, chunk_finish=True)
# whether chunk finish
chunk_finish = send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids)
add_tranfer_task(req_id, send_req_info, chunk_finish=chunk_finish)
if chunk_finish:
self._reqs_need_send_layerwise.pop(req_id)
# chunk
elif (send_req_info.local_computed_tokens // self.block_size) - (
send_req_info.local_transferred_tokens // self.block_size
) > 0:
add_tranfer_task(req_id, send_req_info)
return meta
def _access_metaserver(self, url, message):
@@ -796,13 +806,7 @@ class MooncakeLayerwiseConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._get_prefill_decode_size(vllm_config)
os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(get_transfer_timeout_value())
if self._prefill_tp_size < self._decode_tp_size:
raise ValueError(
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
f" or equal to the decode_tp_size: {self._decode_tp_size}"
)
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
@@ -814,11 +818,20 @@ class MooncakeLayerwiseConnectorWorker:
self.engine_id = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.tp_group = get_tp_group()
self._decode_tp_size: int | None = None
self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip()
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
self.use_mla = self.vllm_config.model_config.use_mla
if self.use_mla:
self.total_num_kv_heads = 1
else:
self.total_num_kv_heads = self.vllm_config.model_config.get_total_num_kv_heads()
# Handshake base port
self.side_channel_port = (
@@ -863,23 +876,6 @@ class MooncakeLayerwiseConnectorWorker:
self.k_buffer: torch.Tensor | None = None
self.v_buffer: torch.Tensor | None = None
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
prefill_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {})
assert "tp_size" in prefill_parallel_config
self._prefill_tp_size = prefill_parallel_config["tp_size"]
assert "dp_size" in prefill_parallel_config
self._prefill_dp_size = prefill_parallel_config["dp_size"]
# get decode tp and dp size from extra config
decode_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("decode", {})
assert "tp_size" in decode_parallel_config
self._decode_tp_size = decode_parallel_config["tp_size"]
assert "dp_size" in decode_parallel_config
self._decode_dp_size = decode_parallel_config["dp_size"]
def create_kv_buffer(self, first_kv_cache):
if self.pd_head_ratio > 1:
# regesit kv buffer for tp inequal
@@ -977,8 +973,6 @@ class MooncakeLayerwiseConnectorWorker:
kv_cache_base_addr=self.kv_caches_base_addr,
use_mla=self.use_mla,
block_len=self.block_len,
decode_tp_size=self._decode_tp_size,
first_kv_cache=first_kv_cache,
k_buffer=self.k_buffer,
v_buffer=self.v_buffer,
resharding_stream=self.resharding_stream,
@@ -1009,9 +1003,120 @@ class MooncakeLayerwiseConnectorWorker:
else set()
)
if len(done_recving) > 0:
logger.info("Number of completed KV cache recv requests: %d, receive requests: %d", 0, len(done_recving))
logger.info(
f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}"
)
return set(), done_recving
# {(ip, port)]: {local_block_ids: [], remote_block_ids: {}}}
def _get_kv_split_metadata(self, req_meta, req_idx, req_id):
remote_pcp_size = req_meta.remote_pcp_size
remote_dcp_size = req_meta.remote_dcp_size
remote_tp_size = req_meta.remote_tp_size
remote_hosts = [req_meta.remote_host]
remote_port = req_meta.remote_port
local_transed_tokens = max(req_meta.remote_cache_tokens, req_meta.local_transed_tokens)
# local_transed_tokens tokens that have already been transmitted on the local side
local_computed_tokens = req_meta.local_computed_tokens
prompt_len = req_meta.prompt_len
p_parallel_info = parallel_info(
tp_size=self.tp_size,
pcp_size=self.pcp_size,
dcp_size=self.dcp_size,
pd_head_ratio=self.pd_head_ratio,
use_mla=self.use_mla,
)
d_parallel_info = parallel_info(
tp_size=remote_tp_size,
pcp_size=remote_pcp_size,
dcp_size=remote_dcp_size,
pd_head_ratio=self.pd_head_ratio,
use_mla=self.use_mla,
)
cp_size = self.pcp_size * self.dcp_size
# to_trans_idx all tokens that have been processed up to the current step
if req_meta.chunk_finish:
to_trans_idx = math.ceil(local_computed_tokens / self.block_size)
else:
to_trans_idx = math.floor(local_computed_tokens / self.block_size)
prompt_block_size = math.ceil(prompt_len / self.block_size)
#
num_local_blocks = prompt_block_size // cp_size + int(
(prompt_block_size % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank)
)
already_send_blocks = to_trans_idx // cp_size + int(
(to_trans_idx % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank)
)
if num_local_blocks == already_send_blocks:
req_meta.chunk_finish = True
transed_idx = math.floor(local_transed_tokens / self.block_size)
p_cp_group = get_cp_group(self.tp_size, self.total_num_kv_heads, self.dcp_size)
d_cp_group = get_cp_group(remote_tp_size, self.total_num_kv_heads, remote_dcp_size)
logger.debug(f"Compute cp group for P&D {req_id=} {p_cp_group=} {d_cp_group=}")
cp_ratio = len(p_cp_group) // len(d_cp_group)
if cp_ratio == 0:
selected_p_cp_groups = p_cp_group
selected_d_cp_groups = d_cp_group
else:
x = req_idx % cp_ratio
start = x * len(d_cp_group)
selected_p_cp_groups = p_cp_group[start : (start + len(d_cp_group))]
selected_d_cp_groups = d_cp_group
assert len(selected_p_cp_groups) == len(selected_d_cp_groups)
p_head_group_rank = (self.tp_rank - self.dcp_rank) // self.dcp_size
selected_p_cp_group = []
selected_d_cp_group = []
for idx, cp_group in enumerate(selected_p_cp_groups):
if p_head_group_rank in cp_group: # Check whether the rank is in selected_p_cp_groups
selected_p_cp_group = cp_group
selected_d_cp_group = selected_d_cp_groups[idx]
if len(selected_p_cp_group) == 0:
return {}
logger.debug(
f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} "
f"P-side selected head_group cp group: {selected_p_cp_group}, "
f"D-side selected head_group cp group: {selected_d_cp_group}"
)
context_parallel_parameters_check(
remote_pcp_size, remote_dcp_size, p_parallel_info, d_parallel_info, self.total_num_kv_heads
)
p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping = (
get_local_remote_block_port_mappings(
to_trans_idx,
p_parallel_info,
d_parallel_info,
remote_hosts,
remote_port,
selected_p_cp_group,
selected_d_cp_group,
prompt_len,
self.block_size,
req_meta,
self.total_num_kv_heads,
req_id,
)
)
transfer_mappings = get_transfer_mappings(
p_rank_block_mapping,
d_block_rank_mapping,
pd_head_mapping,
d_trans_count_mapping,
req_meta,
p_parallel_info,
req_id,
transed_idx,
to_trans_idx,
self.tp_rank,
self.pcp_rank,
self.dcp_rank,
)
return transfer_mappings
def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
"""Start loading KV blocks from remote engine."""
self.current_layer = 0
@@ -1023,31 +1128,29 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
elif self.vllm_config.kv_transfer_config.is_kv_producer:
# select req to send
if self.use_mla or self.use_sparse:
num_need_send = self._decode_tp_size
else:
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
if self.tp_size <= num_kv_head:
num_need_send = self.tp_size
else:
num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
replica_group_idx = self.tp_rank % num_replica_groups
req_ids = sorted(list(metadata.requests.keys()))
selected_req_ids = [
req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
]
request_ids = list(metadata.requests.keys())
for req_id in request_ids:
if req_id not in selected_req_ids:
metadata.requests.pop(req_id)
# update trans info
update_metadata = {}
for req_idx, (req_id, req_meta) in enumerate(metadata.requests.items()):
self._decode_tp_size = req_meta.remote_tp_size
transfer_mappings = self._get_kv_split_metadata(req_meta, req_idx, req_id)
assert len(transfer_mappings) <= 1, f"Not support add mutil transfer task for req_id:{req_id}"
update_req_meta = copy.deepcopy(req_meta)
for (host, port), block_dict in transfer_mappings.items():
update_req_meta.remote_host = host
update_req_meta.remote_port = port
update_req_meta.local_block_ids = block_dict["local_block_ids"]
update_req_meta.remote_block_ids = block_dict["remote_block_ids"]
update_req_meta.trans_count = block_dict["trans_count"]
update_metadata[req_id] = update_req_meta
metadata.requests = {}
for req_id, req_meta in update_metadata.items():
metadata.requests[req_id] = update_metadata[req_id]
# update send task trans block info
if self.pd_head_ratio != 1:
send_task = metadata.send_task
send_task.rearrange_block_ids = sorted(
{block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids}
{block_id for req_id in metadata.requests for block_id in metadata.requests[req_id].local_block_ids}
)
device = self.k_buffer.device # type: ignore
@@ -1070,7 +1173,7 @@ class MooncakeLayerwiseConnectorWorker:
) -> None:
"""MooncakeLayerwiseConnector does not save explicitly."""
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
# enable decode prefix cache
# get reshape and cache event
if self.use_mla or self.use_sparse:
reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
else:
@@ -1156,59 +1259,48 @@ class MooncakeLayerwiseConnectorWorker:
return sock
def update_decoder_info(self, req_id, req_meta):
req_meta_update = copy.deepcopy(req_meta)
if self.use_mla or self.use_sparse:
pd_tp_ratio = self.tp_size // self._decode_tp_size
req_meta_update.remote_port = (
req_meta_update.remote_port + (self.tp_rank // pd_tp_ratio) % self._decode_tp_size
)
else:
req_meta_update.remote_port = (
req_meta_update.remote_port + (self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
)
if (
req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr
or req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]
req_meta.remote_engine_id not in self.remote_kv_caches_base_addr
or req_meta.remote_port not in self.remote_kv_caches_base_addr[req_meta.remote_engine_id]
):
try:
encoded_data = self.encoder.encode((GET_META_MSG, req_id))
sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port)
path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}"
sock = self._get_remote_socket(req_meta.remote_host, req_meta.remote_port)
path = f"{req_meta.remote_host}:{req_meta.remote_port}"
ensure_zmq_send(sock, encoded_data, path)
metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path)
agent_meta = self.decoder.decode(metadata_bytes)
except Exception as e:
logger.error(
f"Query to port and kv base addr for request {req_id} from "
f"{req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}"
f"Query to port and kv base addr for request {req_id}"
f"from {req_meta.remote_host}:{req_meta.remote_port}"
f"fail with error: {e}"
)
assert req_meta_update.remote_engine_id != self.engine_id, (
f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id {self.local_engine_id}."
assert req_meta.remote_engine_id != self.engine_id, (
f"Conflict engine id {req_meta.remote_engine_id} with local engine id {self.local_engine_id}."
)
self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][req_meta_update.remote_port] = (
self.remote_kv_caches_base_addr[req_meta.remote_engine_id][req_meta.remote_port] = (
agent_meta.kv_caches_base_addr
)
self.remote_te_port[req_meta_update.remote_engine_id][req_meta_update.remote_port] = agent_meta.te_rpc_port
self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] = agent_meta.te_rpc_port
logger.info(
f"Query to port and kv base addr for request {req_id} from "
f"{req_meta_update.remote_host}:{req_meta_update.remote_port} success "
f"{agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
f"Query to port and kv base addr for request {req_id}"
f"from {req_meta.remote_host}:{req_meta.remote_port}"
f"success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
)
if self.pd_head_ratio > 1:
# for tp inequal, pre-create link to prevent alltoall out of memory
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
session_id = f"{req_meta.remote_host}:{agent_meta.te_rpc_port}"
ret = self.engine.batch_transfer_sync_write(
session_id, [self.kv_caches_base_addr[0]], [agent_meta.kv_caches_base_addr[0]], [128]
)
if ret < 0:
logger.error(f"Mooncake transfer failed to create link to device {session_id}")
req_meta_update.remote_te_rpc_port = self.remote_te_port[req_meta_update.remote_engine_id][
req_meta_update.remote_port
req_meta.remote_te_rpc_port = self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port]
req_meta.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta.remote_engine_id][
req_meta.remote_port
]
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][
req_meta_update.remote_port
]
return req_meta_update
return req_meta
def send_done_send_signal(self, req_id, req_meta):
external_req_id = get_external_request_id(req_id)
@@ -1221,7 +1313,7 @@ class MooncakeLayerwiseConnectorWorker:
try:
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
msg_encoder = msgspec.msgpack.Encoder()
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count))
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
ack = sock.recv()

View File

@@ -1,7 +1,12 @@
import math
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
import torch
import torch.distributed as dist
from vllm.logger import logger
from vllm_ascend.distributed.parallel_state import get_p_tp_group
@@ -50,3 +55,241 @@ def get_transfer_timeout_value():
hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20")) # type: ignore
hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7")) # type: ignore
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)
@dataclass
class parallel_info:
    """Parallel-layout descriptor for one side (prefill or decode) of a KV transfer.

    NOTE(review): PEP 8 would name this ``ParallelInfo``; renaming is not done
    here because existing callers reference the lowercase name.
    """

    tp_size: int  # tensor-parallel world size of this side
    pcp_size: int  # prefill context-parallel size
    dcp_size: int  # decode context-parallel size
    use_mla: bool  # when True, the per-rank KV-head divisibility check is skipped
    pd_head_ratio: int  # multiplier applied to per-rank transfer counts (P/D head ratio)
def get_cp_group(tp: int, heads: int, dcp: int):
    """Group head-group ranks so that each group covers one complete set of KV heads.

    Partitions the head_group axis of the [pcp][head_group][dcp] layout; a
    head group owns all blocks of a request for the same head.
    Examples: tp=8, dcp=2, heads=4 -> [{0, 1, 2, 3}]
              tp=8, dcp=1, heads=4 -> [{0, 2, 4, 6}, {1, 3, 5, 7}]
    """
    ranks_per_head = tp // heads
    if ranks_per_head == 0:
        # Fewer TP ranks than KV heads: every head-group rank forms one group.
        # NOTE(review): this branch returns list[list[int]] while the branch
        # below returns list[set[int]]; callers only use len()/membership, but
        # unifying the return type upstream would be cleaner.
        return [[i for i in range(tp // dcp)]]
    groups = []
    for group_idx in range(ranks_per_head // dcp):
        members = set()
        for head in range(heads):
            base = head * ranks_per_head + group_idx * dcp
            members.update(k // dcp for k in range(base, base + dcp))
        groups.append(members)
    return groups
def context_parallel_parameters_check(
    remote_pcp_size: int,
    remote_dcp_size: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    total_num_kv_heads: int,
):
    """Validate that the prefill-side CP layout can be mapped onto the decode side.

    Raises AssertionError when the pcp*dcp worlds are incompatible, or when the
    per-rank head counts do not divide evenly (non-MLA models only).
    """
    p_cp_world = p_parallel_info.pcp_size * p_parallel_info.dcp_size
    d_cp_world = remote_pcp_size * remote_dcp_size
    # The prefill CP world size must be an integer multiple of the decode CP world size.
    assert p_cp_world % d_cp_world == 0
    if p_parallel_info.use_mla:
        # The head divisibility check does not apply to MLA models.
        return
    p_heads_per_rank = math.ceil(total_num_kv_heads / p_parallel_info.tp_size)
    # NOTE(review): the D side divides by dcp_size while the P side divides by
    # tp_size — confirm this asymmetry is intentional.
    d_heads_per_rank = math.ceil(total_num_kv_heads / d_parallel_info.dcp_size)
    assert d_heads_per_rank % p_heads_per_rank == 0
def get_tp_rank_head_mapping(num_key_value_heads: int, tp_size: int):
    """Map every TP rank to the list of KV head indices it owns.

    Two layouts are supported:
      * ``tp_size <= num_key_value_heads``: heads are split evenly, each rank
        owning ``num_key_value_heads // tp_size`` consecutive heads.
      * ``tp_size > num_key_value_heads``: heads are replicated, each head
        shared by ``tp_size // num_key_value_heads`` consecutive ranks.

    Returns:
        dict mapping tp_rank -> list of head indices, for all ranks in
        ``range(tp_size)``.

    Raises:
        ValueError: if the sizes do not divide evenly in either direction.
    """
    if tp_size <= num_key_value_heads:
        if num_key_value_heads % tp_size != 0:
            raise ValueError(f"Number of heads ({num_key_value_heads}) cannot be evenly divided by TP ({tp_size}).")
        heads_per_rank = num_key_value_heads // tp_size
        return {rank: list(range(rank * heads_per_rank, (rank + 1) * heads_per_rank)) for rank in range(tp_size)}
    if tp_size % num_key_value_heads != 0:
        # BUGFIX: the old message was copy-pasted from the branch above and claimed
        # heads were not divisible by TP; in this branch it is TP that must be
        # divisible by the head count.
        raise ValueError(f"TP ({tp_size}) cannot be evenly divided by number of heads ({num_key_value_heads}).")
    ranks_per_head = tp_size // num_key_value_heads
    return {rank: [rank // ranks_per_head] for rank in range(tp_size)}
def get_head_group_mapping(num_key_value_heads: int, tp_size: int, num_groups: int, select_cp_group: list[int]):
    """Map each selected head_group_rank to the sorted head indices it owns.

    Only group ranks present in ``select_cp_group`` appear in the result.
    """
    if tp_size % num_groups != 0:
        raise ValueError(
            f"Total number of devices ({tp_size}) cannot be divided by the number of groups ({num_groups})."
        )
    group_width = tp_size // num_groups
    per_rank_heads = get_tp_rank_head_mapping(num_key_value_heads, tp_size)
    result = {}
    for group_rank in range(num_groups):
        if group_rank not in select_cp_group:
            continue
        first_rank = group_rank * group_width
        owned: set[int] = set()
        for member in range(first_rank, first_rank + group_width):
            owned.update(per_rank_heads[member])
        result[group_rank] = sorted(owned)
    return result
def get_local_remote_block_port_mappings(
    to_trans_idx: int,
    p_parallel_info: parallel_info,
    d_parallel_info: parallel_info,
    d_hosts: list[str],
    d_port: int,
    selected_p_cp_group: list[int],
    selected_d_cp_group: list[int],
    prompt_len: int,
    block_size: int,
    req_meta,
    total_num_kv_heads: int,
    req_id: str,
):
    """Build the block/rank/port mappings needed to move one request's KV blocks P -> D.

    Args:
        to_trans_idx: number of logical blocks (prefix included) considered for transfer.
        p_parallel_info: CP/TP layout of the prefill side.
        d_parallel_info: CP/TP layout of the decode side.
        d_hosts: decode-side host list; world ranks are spread evenly across hosts.
        d_port: decode-side base port; each world rank listens on ``d_port + world_rank``.
        selected_p_cp_group: P-side head-group ranks selected for this transfer.
        selected_d_cp_group: D-side head-group ranks selected for this transfer.
        prompt_len: prompt length in tokens (drives the per-rank transfer count).
        block_size: tokens per KV block.
        req_meta: request metadata; ``remote_cache_tokens`` gives the prefix-cache hit size.
        total_num_kv_heads: total number of KV heads of the model.
        req_id: request id (logging only).

    Returns:
        Tuple of (p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping,
        d_trans_count_mapping).
    """
    p_head_group_size = p_parallel_info.tp_size // p_parallel_info.dcp_size
    d_head_group_size = d_parallel_info.tp_size // d_parallel_info.dcp_size
    world_size = d_parallel_info.pcp_size * d_head_group_size * d_parallel_info.dcp_size
    # Compute which logic_block_idx corresponds to each tp_rank.
    p_rank_block_mapping: list[list[list[list[int]]]] = [
        [[[] for _ in range(p_parallel_info.dcp_size)] for _ in range(p_head_group_size)]
        for _ in range(p_parallel_info.pcp_size)
    ]
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // p_parallel_info.dcp_size) % p_parallel_info.pcp_size
        dcp_rank = logic_block_idx % p_parallel_info.dcp_size
        for p_head_group_rank in range(p_head_group_size):
            if p_head_group_rank in selected_p_cp_group:
                p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank].append(logic_block_idx)
    # Find the remote device that holds each logic_block_idx.
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
    for logic_block_idx in range(to_trans_idx):
        pcp_rank = (logic_block_idx // d_parallel_info.dcp_size) % d_parallel_info.pcp_size
        for d_head_group_rank in range(d_head_group_size):
            if d_head_group_rank in selected_d_cp_group:
                dcp_rank = logic_block_idx % d_parallel_info.dcp_size
                world_rank = (
                    pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + dcp_rank
                )
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                # The decomposition above implies the D-side layout
                #   logic_block_idx = n * (pcp * dcp) + pcp_rank * dcp + dcp_rank,
                # so the local block counter n subtracts pcp_rank * dcp_size.
                # BUGFIX: this previously subtracted pcp_rank * pcp_size, producing
                # wrong (even negative) block indices whenever pcp_size != dcp_size.
                block_idx = (logic_block_idx - (pcp_rank * d_parallel_info.dcp_size + dcp_rank)) // (
                    d_parallel_info.pcp_size * d_parallel_info.dcp_size
                )
                d_block_rank_mapping[logic_block_idx][d_head_group_rank] = {
                    "pcp_rank": pcp_rank,
                    "dcp_rank": dcp_rank,
                    "host": host,
                    "port": port,
                    "block_idx": block_idx,
                }
    # Get how many times each device should receive done_single for this request.
    d_trans_count_mapping = {}
    trans_block_size = math.ceil(prompt_len / block_size)  # Total number of blocks
    transed_block_size = math.ceil(req_meta.remote_cache_tokens / block_size)  # Number of prefix cache hit blocks
    d_cp_size = d_parallel_info.pcp_size * d_parallel_info.dcp_size
    for d_pcp_rank in range(d_parallel_info.pcp_size):
        for d_head_group_rank in range(d_head_group_size):
            for d_dcp_rank in range(d_parallel_info.dcp_size):
                if trans_block_size >= (p_parallel_info.pcp_size * p_parallel_info.dcp_size):
                    trans_count = (p_parallel_info.pcp_size * p_parallel_info.dcp_size) // d_cp_size
                else:
                    # Fewer blocks than P-side CP ranks: distribute the global block
                    # count round-robin over the D CP ranks, then subtract what the
                    # prefix cache already delivered to this rank.
                    current_rank_idx = d_pcp_rank * d_parallel_info.dcp_size + d_dcp_rank
                    total_global_blocks = transed_block_size + trans_block_size
                    target_total_count = total_global_blocks // d_cp_size
                    if current_rank_idx < (total_global_blocks % d_cp_size):
                        target_total_count += 1
                    prev_processed_count = transed_block_size // d_cp_size
                    if current_rank_idx < (transed_block_size % d_cp_size):
                        prev_processed_count += 1
                    trans_count = target_total_count - prev_processed_count
                world_rank = (
                    d_pcp_rank * d_head_group_size * d_parallel_info.dcp_size
                    + d_head_group_rank * d_parallel_info.dcp_size
                    + d_dcp_rank
                )
                host = d_hosts[(len(d_hosts) * world_rank) // world_size]
                port = d_port + world_rank
                d_trans_count_mapping[(host, port)] = trans_count * p_parallel_info.pd_head_ratio
    # Compute the mapping between local and remote head_group_rank.
    p_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, p_parallel_info.tp_size, p_head_group_size, selected_p_cp_group
    )
    d_tp_rank_head_mapping = get_head_group_mapping(
        total_num_kv_heads, d_parallel_info.tp_size, d_head_group_size, selected_d_cp_group
    )
    head_to_d_groups = defaultdict(set)
    for d_rank, heads in d_tp_rank_head_mapping.items():
        for head in heads:
            head_to_d_groups[head].add(d_rank)
    pd_head_mapping = {}
    for p_rank, p_heads in p_tp_rank_head_mapping.items():
        target_d_ranks = set()
        for head in p_heads:
            if head in head_to_d_groups:
                target_d_ranks.update(head_to_d_groups[head])
            else:
                logger.info(f"Warning: Head {head} exists in P but not in D mapping.")
        pd_head_mapping[p_rank] = sorted(target_d_ranks)
    logger.debug(
        f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} "
        f"P-side logic_block to rank mapping: {p_rank_block_mapping}, "
        f"D-side logic_block to rank mapping: {d_block_rank_mapping}, "
        f"P&D head_group_rank mapping: {pd_head_mapping}"
    )
    return p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping
def get_transfer_mappings(
    p_rank_block_mapping: list[list[list[list[int]]]],
    d_block_rank_mapping: dict[int, dict[int, dict[str, Any]]],
    pd_head_mapping: dict[int, list[int]],
    d_trans_count_mapping: dict[tuple[str, int], int],
    req_meta,
    p_parallel_info: parallel_info,
    req_id: str,
    transed_idx: int,
    to_trans_idx: int,
    tp_rank: int,
    pcp_rank: int,
    dcp_rank: int,
) -> dict[tuple[str, int], dict[str, Any]]:
    """Turn the precomputed block/rank mappings into per-(host, port) transfer tasks.

    For every logical block owned by this rank's (pcp, head_group, dcp) position
    that lies in the window [transed_idx, to_trans_idx), pair the local block id
    with the corresponding remote block id on each mapped D-side head group, and
    attach the done-signal count expected by that remote endpoint.

    Returns:
        dict keyed by (remote_host, remote_port) with ``local_block_ids``,
        ``remote_block_ids`` and ``trans_count`` per endpoint.
    """
    transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {}
    # Collapse this rank's tp position into its head-group rank (strip the dcp component).
    p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size
    # Logical block indices owned by this (pcp, head_group, dcp) position.
    p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank]
    for p_block_idx, logic_block_idx in enumerate(p_block_idxs):
        # Skip blocks already transferred or beyond the current transfer window.
        if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx:
            continue
        for d_head_group_rank in pd_head_mapping[p_head_group_rank]:
            # NOTE(review): indexing local_block_ids by the position within
            # p_block_idxs assumes req_meta.local_block_ids is ordered exactly
            # like this rank's block list (including entries below transed_idx)
            # — confirm with the caller.
            p_block_id = req_meta.local_block_ids[p_block_idx]
            remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"]
            remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"]
            d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"]
            d_block_id = req_meta.remote_block_ids[d_block_idx]
            if (remote_host, remote_port) not in transfer_mappings:
                transfer_mappings[(remote_host, remote_port)] = {
                    "local_block_ids": [],
                    "remote_block_ids": [],
                    "trans_count": 0,
                }
            transfer_mappings[(remote_host, remote_port)]["local_block_ids"].append(p_block_id)
            transfer_mappings[(remote_host, remote_port)]["remote_block_ids"].append(d_block_id)
    # Attach the expected done-signal count for every endpoint this rank talks to.
    for (host, port), block_dict in transfer_mappings.items():
        block_dict["trans_count"] = d_trans_count_mapping[(host, port)]
    logger.debug(f"MooncakeLayerwiseConnector Request {req_id} transfer tasks: {transfer_mappings}")
    return transfer_mappings