[Bugfix] Support MTP KV transfer and manual PP partitioning in KV transfer (#4892)

### What this PR does / why we need it?
The current mooncake connector has the following problems when PP and MTP are
enabled:
1. MTP layer KV caches are not transferred, which can reduce the acceptance
rate. This PR adds the MTP layer indices for the last PP stage after
calculating `end_layer` in `transfer_kv_cache` (see the sketch after this list).
2. With MTP enabled, the default PP layer split can leave the stages
imbalanced, so the `VLLM_PP_LAYER_PARTITION` environment variable is needed to
balance them by hand; however, during mooncake connector KV transfer the decode
node does not know the prefill node's partition. This PR adds a
`pp_layer_partition` option to `kv_connector_extra_config` so the decode node
can obtain the prefill node's partition information.
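
A minimal sketch of the layer-range logic this PR adds on the decode side. The function name and the even-split fallback are simplified stand-ins for the connector's `get_prefill_pp_indices` and vLLM's `get_pp_indices`; only the last prefill PP stage gets the MTP extension:

```python
from typing import Optional

def prefill_layer_range(num_hidden_layers: int,
                        pp_rank: int,
                        pp_size: int,
                        partition: Optional[str] = None,
                        num_speculative_tokens: int = 0) -> tuple[int, int]:
    """Layer range a decode node expects from one prefill PP stage."""
    if partition is None:
        # Simplified even split; the connector actually falls back to
        # vLLM's get_pp_indices() in this case.
        per_stage = num_hidden_layers // pp_size
        start = pp_rank * per_stage
        end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_stage
    else:
        sizes = [int(s) for s in partition.split(",")]
        assert len(sizes) == pp_size and sum(sizes) == num_hidden_layers
        start = sum(sizes[:pp_rank])
        end = start + sizes[pp_rank]
    # MTP draft layers sit after the last transformer layer, so only the
    # last prefill PP stage transfers the extra KV caches.
    if pp_rank == pp_size - 1:
        end += num_speculative_tokens
    return start, end

# A 61-layer model split "33,28" with one MTP (speculative) layer:
print(prefill_layer_range(61, 0, 2, "33,28", 1))  # (0, 33)
print(prefill_layer_range(61, 1, 2, "33,28", 1))  # (33, 62)
```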

### Does this PR introduce _any_ user-facing change?
When the prefill node uses the `VLLM_PP_LAYER_PARTITION` environment variable,
add `pp_layer_partition` to `kv_connector_extra_config` as shown below:
```
export VLLM_PP_LAYER_PARTITION=33,28

"kv_connector_extra_config": {
    "use_ascend_direct": true,
    "prefill": {
        "dp_size": 1,
        "tp_size": 8,
        "pp_size": 2,
        "pp_layer_partition": "33,28"
    },
    "decode": {
        "dp_size": 16,
        "tp_size": 1,
        "pp_size": 1
    }
}
```

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
Author: lidenghui1110
Date: 2025-12-11 17:23:21 +08:00 (committed by GitHub)
Commit: 332b547728 (parent a47aa4da2f)
2 changed files with 71 additions and 14 deletions

Test updates (KVCacheRecvingThread unit tests):

```diff
@@ -242,7 +242,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)

     def test_add_request(self):
@@ -295,7 +296,8 @@ class TestSocketManagement(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.remote_sockets = defaultdict(deque)
         self.thread.remote_poller = MagicMock()
@@ -352,7 +354,8 @@ class TestCoreFunctionality(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.request_queue = self.mock_queue
         self.test_req = {
             "request_id": "req1",
@@ -434,7 +437,8 @@ class TestMetadataHandling(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.test_metadata = MooncakeAgentMetadata(
             engine_id="remote_engine",
             te_rpc_port=9090,
@@ -498,7 +502,8 @@ class TestMainThreadLoop(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.request_queue = queue.Queue()

     @patch.object(KVCacheRecvingThread, '_handle_request')
@@ -535,6 +540,7 @@ class MockVllmConfig:
         self.parallel_config = MagicMock()
         self.cache_config = MagicMock()
         self.kv_transfer_config = MagicMock()
+        self.speculative_config = MagicMock()
         self.model_config.use_mla = True
         self.parallel_config.tensor_parallel_size = 2
         self.parallel_config.data_parallel_rank = 0
```
Connector changes (KVCacheRecvingThread / MooncakeConnectorWorker):

```diff
@@ -271,12 +271,19 @@ class KVCacheSendingThread(threading.Thread):
 class KVCacheRecvingThread(threading.Thread):

-    def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
-                 engine: TransferEngine, local_engine_id: str,
+    def __init__(self,
+                 tp_rank: int,
+                 tp_size: int,
+                 _prefill_pp_size: int,
+                 engine: TransferEngine,
+                 local_engine_id: str,
                  local_handshake_port: int,
-                 local_kv_caches_base_addr: list[int], block_len: list[int],
-                 ready_event: threading.Event, vllm_config: VllmConfig,
-                 kv_caches: dict[str, Any]):
+                 local_kv_caches_base_addr: list[int],
+                 block_len: list[int],
+                 ready_event: threading.Event,
+                 vllm_config: VllmConfig,
+                 kv_caches: dict[str, Any],
+                 prefill_pp_layer_partition: Optional[str] = None):
         super().__init__(daemon=True, name="KVCacheRecvingThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -315,6 +322,14 @@ class KVCacheRecvingThread(threading.Thread):
         self.vllm_config = vllm_config
         self.model_config = self.vllm_config.model_config
         self.block_size = self.vllm_config.cache_config.block_size
+        self.num_layers = self.model_config.hf_config.num_hidden_layers
+        self.pp_layer_indices = {
+            rank:
+            get_prefill_pp_indices(self.num_layers, rank,
+                                   self._prefill_pp_size,
+                                   prefill_pp_layer_partition)
+            for rank in range(self._prefill_pp_size)
+        }
         if self.use_mla:
             self.k_head_dim = self.model_config.hf_config.kv_lora_rank
             self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
@@ -435,9 +450,14 @@ class KVCacheRecvingThread(threading.Thread):
         remote_kv_caches_base_addrs = \
             self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
-        num_layers = self.model_config.hf_config.num_hidden_layers
-        first_layer_index, end_layer_index = get_pp_indices(
-            num_layers, prefill_pp_rank, self._prefill_pp_size)
+        first_layer_index, end_layer_index = self.pp_layer_indices[
+            prefill_pp_rank]
+        # support MTP layer kv transfer
+        if self.vllm_config.speculative_config is not None:
+            num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens
+            num_speculative_tokens = 0 if num_speculative_tokens is None else num_speculative_tokens
+            if prefill_pp_rank == self._prefill_pp_size - 1:
+                end_layer_index = end_layer_index + num_speculative_tokens
         num_cache_per_layer = len(list(
             self.kv_caches.values())[0])  # Number of KV caches per layer
         local_kv_caches_base_addrs = \
@@ -1020,6 +1040,8 @@ class MooncakeConnectorWorker:
         # get prefill pp size from extra config
         self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
         assert self._decode_pp_size == 1, "decode pp size must be 1"
+        self._prefill_pp_layer_partition = prefill_parallel_config.get(
+            "pp_layer_partition", None)

     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""
@@ -1126,7 +1148,8 @@ class MooncakeConnectorWorker:
         self.kv_recv_thread = KVCacheRecvingThread(
             self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
             self.engine_id, self.handshake_port, kv_caches_base_addr,
-            self.block_len, ready_event, self.vllm_config, self.kv_caches)
+            self.block_len, ready_event, self.vllm_config, self.kv_caches,
+            self._prefill_pp_layer_partition)
         self.kv_recv_thread.start()
         ready_event.wait()
@@ -1455,3 +1478,31 @@ def ensure_zmq_recv(
                 raise RuntimeError(
                     f"Failed to receive data after {max_retries} "
                     f"retries: {e}")
+
+
+# decode node should know pp_partition_layer in prefill node,
+# it is configured in kv_transfer_config by partition_list_str,
+# default using vllm layer split algorithm.
+def get_prefill_pp_indices(
+        num_hidden_layers: int,
+        pp_rank: int,
+        pp_size: int,
+        partition_list_str: Optional[str] = None) -> tuple[int, int]:
+    if partition_list_str is None:
+        return get_pp_indices(num_hidden_layers, pp_rank, pp_size)
+    else:
+        try:
+            partitions = [
+                int(layer) for layer in partition_list_str.split(",")
+            ]
+        except ValueError as err:
+            raise ValueError("Invalid partition string: {}".format(
+                partition_list_str)) from err
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(
+                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+        start_layer = sum(partitions[:pp_rank])
+        end_layer = start_layer + partitions[pp_rank]
+        return (start_layer, end_layer)
```
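
For illustration, a quick check of the new helper with the `33,28` partition from the example above, assuming a 61-layer model:

```python
# Hypothetical standalone calls to the helper added in this PR.
print(get_prefill_pp_indices(61, 0, 2, "33,28"))  # (0, 33)
print(get_prefill_pp_indices(61, 1, 2, "33,28"))  # (33, 61)
# Without a partition string it falls back to vLLM's get_pp_indices().
```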