[Bugfix] support mtp kv transfer and pp partition by hand in kv transfer (#4892)

### What this PR does / why we need it?
Current mooncake connector has following problems with PP and MTP
enabled:
1. MTP layer kv caches are not transfered, it may cause decreasing of
accept ratio: This PR add MTP layer indices for last PP stage after
calculating end_layer in transfer_kv_cache
2. While MTP enabled, PP layers divided by default may cause imbalance
between stages, we need to use `VLLM_PP_LAYER_PARTITION` environment to
make it balance by hand, but in mooncake connector kv transfer, decode
doesn't know the partition of prefill node: This PR add config
`pp_layer_partition` in `kv_connector_extra_config` to make decode node
acquire the partition information of prefill node.

### Does this PR introduce _any_ user-facing change?
When prefill using `VLLM_PP_LAYER_PARTITION` environment, add
`pp_layer_partition` in `kv_connector_extra_config` like below:
```
export VLLM_PP_LAYER_PARTITION=33,28
"kv_connector_extra_config": {
    "use_ascend_direct": true,
    "prefill": {
            "dp_size": 1,
            "tp_size": 8,
            "pp_size": 2,
            "pp_layer_partition": "33,28"
     },
     "decode": {
            "dp_size": 16,
            "tp_size": 1,
            "pp_size": 1
     }
}
```

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
lidenghui1110
2025-12-11 17:23:21 +08:00
committed by GitHub
parent a47aa4da2f
commit 332b547728
2 changed files with 71 additions and 14 deletions

View File

@@ -242,7 +242,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
def test_add_request(self):
test_req = {
@@ -295,7 +296,8 @@ class TestSocketManagement(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock()
@@ -352,7 +354,8 @@ class TestCoreFunctionality(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.thread.request_queue = self.mock_queue
self.test_req = {
"request_id": "req1",
@@ -434,7 +437,8 @@ class TestMetadataHandling(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.test_metadata = MooncakeAgentMetadata(
engine_id="remote_engine",
te_rpc_port=9090,
@@ -498,7 +502,8 @@ class TestMainThreadLoop(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.thread.request_queue = queue.Queue()
@patch.object(KVCacheRecvingThread, '_handle_request')
@@ -535,6 +540,7 @@ class MockVllmConfig:
self.parallel_config = MagicMock()
self.cache_config = MagicMock()
self.kv_transfer_config = MagicMock()
self.speculative_config = MagicMock()
self.model_config.use_mla = True
self.parallel_config.tensor_parallel_size = 2
self.parallel_config.data_parallel_rank = 0