[Recover] [Bugfix] support mtp kv transfer and pp partition by hand in kv transfer (#4892) (revert in #4981) (#5511)

PR #4892 was reverted in #4981; we recover it now. As for the potential bug
breaking deepseek3.2 in the PD case, we will track it down and fix it.

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

---------

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
lidenghui1110
2026-01-04 16:49:33 +08:00
committed by GitHub
parent 7c210225a2
commit d462577504
2 changed files with 72 additions and 15 deletions

View File

@@ -246,7 +246,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
def test_add_request(self):
test_req = {
@@ -300,7 +301,8 @@ class TestSocketManagement(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock()
@@ -358,7 +360,8 @@ class TestCoreFunctionality(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.thread.request_queue = self.mock_queue
self.test_req = {
"request_id": "req1",
@@ -444,7 +447,8 @@ class TestMetadataHandling(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.test_metadata = MooncakeAgentMetadata(
engine_id="remote_engine",
te_rpc_port=9090,
@@ -509,7 +513,8 @@ class TestMainThreadLoop(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
self.thread.request_queue = queue.Queue()
@patch.object(KVCacheRecvingThread, '_handle_request')
@@ -546,6 +551,7 @@ class MockVllmConfig:
self.parallel_config = MagicMock()
self.cache_config = MagicMock()
self.kv_transfer_config = MagicMock()
self.speculative_config = MagicMock()
self.model_config.use_mla = True
self.parallel_config.tensor_parallel_size = 2
self.parallel_config.data_parallel_rank = 0

View File

@@ -292,12 +292,20 @@ class KVCacheSendingThread(threading.Thread):
class KVCacheRecvingThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
engine: TransferEngine, local_engine_id: str,
local_handshake_port: int, side_channel_port: int,
local_kv_caches_base_addr: list[int], block_len: list[int],
ready_event: threading.Event, vllm_config: VllmConfig,
kv_caches: dict[str, Any]):
def __init__(self,
tp_rank: int,
tp_size: int,
_prefill_pp_size: int,
engine: TransferEngine,
local_engine_id: str,
local_handshake_port: int,
side_channel_port: int,
local_kv_caches_base_addr: list[int],
block_len: list[int],
ready_event: threading.Event,
vllm_config: VllmConfig,
kv_caches: dict[str, Any],
prefill_pp_layer_partition: Optional[str] = None):
super().__init__(daemon=True, name="KVCacheRecvingThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
@@ -337,6 +345,14 @@ class KVCacheRecvingThread(threading.Thread):
self.vllm_config = vllm_config
self.model_config = self.vllm_config.model_config
self.block_size = self.vllm_config.cache_config.block_size
self.num_layers = self.model_config.hf_config.num_hidden_layers
self.pp_layer_indices = {
rank:
get_prefill_pp_indices(self.num_layers, rank,
self._prefill_pp_size,
prefill_pp_layer_partition)
for rank in range(self._prefill_pp_size)
}
if not is_vl_model(vllm_config):
if self.use_mla:
self.k_head_dim = self.model_config.hf_text_config.kv_lora_rank
@@ -490,9 +506,13 @@ class KVCacheRecvingThread(threading.Thread):
remote_kv_caches_base_addrs = \
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
num_layers = self.model_config.hf_text_config.num_hidden_layers
first_layer_index, end_layer_index = get_pp_indices(
num_layers, prefill_pp_rank, self._prefill_pp_size)
first_layer_index, end_layer_index = self.pp_layer_indices[
prefill_pp_rank]
# support MTP layer kv transfer
if self.vllm_config.speculative_config is not None:
# all MTP layer use the same kv cache layer, so only need to transfer once
if prefill_pp_rank == self._prefill_pp_size - 1:
end_layer_index = end_layer_index + 1
num_cache_per_layer = len(list(
self.kv_caches.values())[0]) # Number of KV caches per layer
local_kv_caches_base_addrs = \
@@ -1161,6 +1181,8 @@ class MooncakeConnectorWorker:
# get prefill pp size from extra config
self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
assert self._decode_pp_size == 1, "decode pp size must be 1"
self._prefill_pp_layer_partition = prefill_parallel_config.get(
"pp_layer_partition", None)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data."""
@@ -1270,7 +1292,8 @@ class MooncakeConnectorWorker:
self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
self.engine_id, self.handshake_port, self.side_channel_port,
kv_caches_base_addr, self.block_len, ready_event,
self.vllm_config, self.kv_caches)
self.vllm_config, self.kv_caches,
self._prefill_pp_layer_partition)
self.kv_recv_thread.start()
start_wait_time = time.time()
@@ -1761,3 +1784,31 @@ def ensure_zmq_recv(
raise RuntimeError(
f"Failed to receive data after {max_retries} "
f"retries: {e}")
# The decode node needs to know the prefill node's pp layer partition.
# It is configured via "pp_layer_partition" in kv_transfer_config;
# by default, vLLM's layer-split algorithm is used.
def get_prefill_pp_indices(
        num_hidden_layers: int,
        pp_rank: int,
        pp_size: int,
        partition_list_str: Optional[str] = None) -> tuple[int, int]:
    """Return the ``[start, end)`` hidden-layer range owned by ``pp_rank``.

    Args:
        num_hidden_layers: Total number of hidden layers in the model.
        pp_rank: Pipeline-parallel rank to compute the layer range for.
        pp_size: Pipeline-parallel world size on the prefill node.
        partition_list_str: Optional comma-separated per-rank layer counts
            (e.g. ``"4,6"``). When ``None``, vLLM's default layer-split
            algorithm is used instead.

    Raises:
        ValueError: If the partition string cannot be parsed as integers,
            has a number of entries different from ``pp_size``, or whose
            entries do not sum to ``num_hidden_layers``.
    """
    # No explicit partition configured: defer to vLLM's split algorithm.
    if partition_list_str is None:
        return get_pp_indices(num_hidden_layers, pp_rank, pp_size)
    try:
        partitions = list(map(int, partition_list_str.split(",")))
    except ValueError as err:
        raise ValueError(
            f"Invalid partition string: {partition_list_str}") from err
    if len(partitions) != pp_size:
        raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
    if sum(partitions) != num_hidden_layers:
        raise ValueError(
            f"{sum(partitions)=} does not match {num_hidden_layers=}.")
    # Layers before this rank determine the start; this rank's count, the end.
    start_layer = sum(partitions[:pp_rank])
    return (start_layer, start_layer + partitions[pp_rank])