Revert "[Bugfix] support mtp kv transfer and pp partition by hand in kv transfer (#4892)" (#4981)

This reverts commit 332b547728.

This breaks DeepSeek 3.2 in the PD (prefill/decode disaggregation) case.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
This commit is contained in:
wangxiyuan
2025-12-13 18:58:55 +08:00
committed by GitHub
parent 31c94b7e7b
commit 5211e991ad
2 changed files with 14 additions and 71 deletions

View File

@@ -242,8 +242,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
kv_caches=self.kv_caches)
def test_add_request(self):
test_req = {
@@ -296,8 +295,7 @@ class TestSocketManagement(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
kv_caches=self.kv_caches)
self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock()
@@ -354,8 +352,7 @@ class TestCoreFunctionality(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
kv_caches=self.kv_caches)
self.thread.request_queue = self.mock_queue
self.test_req = {
"request_id": "req1",
@@ -437,8 +434,7 @@ class TestMetadataHandling(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
kv_caches=self.kv_caches)
self.test_metadata = MooncakeAgentMetadata(
engine_id="remote_engine",
te_rpc_port=9090,
@@ -502,8 +498,7 @@ class TestMainThreadLoop(unittest.TestCase):
block_len=[1024, 2048],
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches,
prefill_pp_layer_partition=None)
kv_caches=self.kv_caches)
self.thread.request_queue = queue.Queue()
@patch.object(KVCacheRecvingThread, '_handle_request')
@@ -540,7 +535,6 @@ class MockVllmConfig:
self.parallel_config = MagicMock()
self.cache_config = MagicMock()
self.kv_transfer_config = MagicMock()
self.speculative_config = MagicMock()
self.model_config.use_mla = True
self.parallel_config.tensor_parallel_size = 2
self.parallel_config.data_parallel_rank = 0

View File

@@ -276,19 +276,12 @@ class KVCacheSendingThread(threading.Thread):
class KVCacheRecvingThread(threading.Thread):
def __init__(self,
tp_rank: int,
tp_size: int,
_prefill_pp_size: int,
engine: TransferEngine,
local_engine_id: str,
def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
engine: TransferEngine, local_engine_id: str,
local_handshake_port: int,
local_kv_caches_base_addr: list[int],
block_len: list[int],
ready_event: threading.Event,
vllm_config: VllmConfig,
kv_caches: dict[str, Any],
prefill_pp_layer_partition: Optional[str] = None):
local_kv_caches_base_addr: list[int], block_len: list[int],
ready_event: threading.Event, vllm_config: VllmConfig,
kv_caches: dict[str, Any]):
super().__init__(daemon=True, name="KVCacheRecvingThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
@@ -327,14 +320,6 @@ class KVCacheRecvingThread(threading.Thread):
self.vllm_config = vllm_config
self.model_config = self.vllm_config.model_config
self.block_size = self.vllm_config.cache_config.block_size
self.num_layers = self.model_config.hf_config.num_hidden_layers
self.pp_layer_indices = {
rank:
get_prefill_pp_indices(self.num_layers, rank,
self._prefill_pp_size,
prefill_pp_layer_partition)
for rank in range(self._prefill_pp_size)
}
if self.use_mla:
self.k_head_dim = self.model_config.hf_config.kv_lora_rank
self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
@@ -455,14 +440,9 @@ class KVCacheRecvingThread(threading.Thread):
remote_kv_caches_base_addrs = \
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
first_layer_index, end_layer_index = self.pp_layer_indices[
prefill_pp_rank]
# support MTP layer kv transfer
if self.vllm_config.speculative_config is not None:
num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens
num_speculative_tokens = 0 if num_speculative_tokens is None else num_speculative_tokens
if prefill_pp_rank == self._prefill_pp_size - 1:
end_layer_index = end_layer_index + num_speculative_tokens
num_layers = self.model_config.hf_config.num_hidden_layers
first_layer_index, end_layer_index = get_pp_indices(
num_layers, prefill_pp_rank, self._prefill_pp_size)
num_cache_per_layer = len(list(
self.kv_caches.values())[0]) # Number of KV caches per layer
local_kv_caches_base_addrs = \
@@ -1045,8 +1025,6 @@ class MooncakeConnectorWorker:
# get prefill pp size from extra config
self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
assert self._decode_pp_size == 1, "decode pp size must be 1"
self._prefill_pp_layer_partition = prefill_parallel_config.get(
"pp_layer_partition", None)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data."""
@@ -1153,8 +1131,7 @@ class MooncakeConnectorWorker:
self.kv_recv_thread = KVCacheRecvingThread(
self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
self.engine_id, self.handshake_port, kv_caches_base_addr,
self.block_len, ready_event, self.vllm_config, self.kv_caches,
self._prefill_pp_layer_partition)
self.block_len, ready_event, self.vllm_config, self.kv_caches)
self.kv_recv_thread.start()
start_wait_time = time.time()
@@ -1494,31 +1471,3 @@ def ensure_zmq_recv(
raise RuntimeError(
f"Failed to receive data after {max_retries} "
f"retries: {e}")
# The decode node needs to know the prefill node's pp_layer_partition;
# it is configured in kv_transfer_config via partition_list_str and
# defaults to vLLM's built-in layer-split algorithm.
def get_prefill_pp_indices(
num_hidden_layers: int,
pp_rank: int,
pp_size: int,
partition_list_str: Optional[str] = None) -> tuple[int, int]:
if partition_list_str is None:
return get_pp_indices(num_hidden_layers, pp_rank, pp_size)
else:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != pp_size:
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
return (start_layer, end_layer)