[Bugfix] support mtp kv transfer and pp partition by hand in kv transfer (#4892)
### What this PR does / why we need it?
Current mooncake connector has following problems with PP and MTP
enabled:
1. MTP layer kv caches are not transferred, which may cause a decrease in
the acceptance ratio: This PR adds MTP layer indices for the last PP stage
after calculating end_layer in transfer_kv_cache
2. With MTP enabled, the default PP layer division may cause imbalance
between stages; we need to use the `VLLM_PP_LAYER_PARTITION` environment
variable to balance them by hand, but in mooncake connector kv transfer the
decode node doesn't know the partition of the prefill node: This PR adds the
config `pp_layer_partition` in `kv_connector_extra_config` so the decode
node can acquire the partition information of the prefill node.
### Does this PR introduce _any_ user-facing change?
When the prefill node uses the `VLLM_PP_LAYER_PARTITION` environment
variable, add `pp_layer_partition` in `kv_connector_extra_config` like below:
```
export VLLM_PP_LAYER_PARTITION=33,28
"kv_connector_extra_config": {
"use_ascend_direct": true,
"prefill": {
"dp_size": 1,
"tp_size": 8,
"pp_size": 2,
"pp_layer_partition": "33,28"
},
"decode": {
"dp_size": 16,
"tp_size": 1,
"pp_size": 1
}
}
```
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
This commit is contained in:
@@ -242,7 +242,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
|
||||
def test_add_request(self):
|
||||
test_req = {
|
||||
@@ -295,7 +296,8 @@ class TestSocketManagement(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
self.thread.remote_sockets = defaultdict(deque)
|
||||
self.thread.remote_poller = MagicMock()
|
||||
|
||||
@@ -352,7 +354,8 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
self.thread.request_queue = self.mock_queue
|
||||
self.test_req = {
|
||||
"request_id": "req1",
|
||||
@@ -434,7 +437,8 @@ class TestMetadataHandling(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
self.test_metadata = MooncakeAgentMetadata(
|
||||
engine_id="remote_engine",
|
||||
te_rpc_port=9090,
|
||||
@@ -498,7 +502,8 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
self.thread.request_queue = queue.Queue()
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_handle_request')
|
||||
@@ -535,6 +540,7 @@ class MockVllmConfig:
|
||||
self.parallel_config = MagicMock()
|
||||
self.cache_config = MagicMock()
|
||||
self.kv_transfer_config = MagicMock()
|
||||
self.speculative_config = MagicMock()
|
||||
self.model_config.use_mla = True
|
||||
self.parallel_config.tensor_parallel_size = 2
|
||||
self.parallel_config.data_parallel_rank = 0
|
||||
|
||||
@@ -271,12 +271,19 @@ class KVCacheSendingThread(threading.Thread):
|
||||
|
||||
class KVCacheRecvingThread(threading.Thread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
|
||||
engine: TransferEngine, local_engine_id: str,
|
||||
def __init__(self,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
_prefill_pp_size: int,
|
||||
engine: TransferEngine,
|
||||
local_engine_id: str,
|
||||
local_handshake_port: int,
|
||||
local_kv_caches_base_addr: list[int], block_len: list[int],
|
||||
ready_event: threading.Event, vllm_config: VllmConfig,
|
||||
kv_caches: dict[str, Any]):
|
||||
local_kv_caches_base_addr: list[int],
|
||||
block_len: list[int],
|
||||
ready_event: threading.Event,
|
||||
vllm_config: VllmConfig,
|
||||
kv_caches: dict[str, Any],
|
||||
prefill_pp_layer_partition: Optional[str] = None):
|
||||
super().__init__(daemon=True, name="KVCacheRecvingThread")
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
@@ -315,6 +322,14 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = self.vllm_config.model_config
|
||||
self.block_size = self.vllm_config.cache_config.block_size
|
||||
self.num_layers = self.model_config.hf_config.num_hidden_layers
|
||||
self.pp_layer_indices = {
|
||||
rank:
|
||||
get_prefill_pp_indices(self.num_layers, rank,
|
||||
self._prefill_pp_size,
|
||||
prefill_pp_layer_partition)
|
||||
for rank in range(self._prefill_pp_size)
|
||||
}
|
||||
if self.use_mla:
|
||||
self.k_head_dim = self.model_config.hf_config.kv_lora_rank
|
||||
self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
|
||||
@@ -435,9 +450,14 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
|
||||
remote_kv_caches_base_addrs = \
|
||||
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
|
||||
num_layers = self.model_config.hf_config.num_hidden_layers
|
||||
first_layer_index, end_layer_index = get_pp_indices(
|
||||
num_layers, prefill_pp_rank, self._prefill_pp_size)
|
||||
first_layer_index, end_layer_index = self.pp_layer_indices[
|
||||
prefill_pp_rank]
|
||||
# support MTP layer kv transfer
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens
|
||||
num_speculative_tokens = 0 if num_speculative_tokens is None else num_speculative_tokens
|
||||
if prefill_pp_rank == self._prefill_pp_size - 1:
|
||||
end_layer_index = end_layer_index + num_speculative_tokens
|
||||
num_cache_per_layer = len(list(
|
||||
self.kv_caches.values())[0]) # Number of KV caches per layer
|
||||
local_kv_caches_base_addrs = \
|
||||
@@ -1020,6 +1040,8 @@ class MooncakeConnectorWorker:
|
||||
# get prefill pp size from extra config
|
||||
self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
|
||||
assert self._decode_pp_size == 1, "decode pp size must be 1"
|
||||
self._prefill_pp_layer_partition = prefill_parallel_config.get(
|
||||
"pp_layer_partition", None)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data."""
|
||||
@@ -1126,7 +1148,8 @@ class MooncakeConnectorWorker:
|
||||
self.kv_recv_thread = KVCacheRecvingThread(
|
||||
self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
|
||||
self.engine_id, self.handshake_port, kv_caches_base_addr,
|
||||
self.block_len, ready_event, self.vllm_config, self.kv_caches)
|
||||
self.block_len, ready_event, self.vllm_config, self.kv_caches,
|
||||
self._prefill_pp_layer_partition)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
@@ -1455,3 +1478,31 @@ def ensure_zmq_recv(
|
||||
raise RuntimeError(
|
||||
f"Failed to receive data after {max_retries} "
|
||||
f"retries: {e}")
|
||||
|
||||
|
||||
# decode node should know pp_partition_layer in prefill node,
# it is configured in kv_transfer_config by partition_list_str,
# default using vllm layer split algorithm.
def get_prefill_pp_indices(
        num_hidden_layers: int,
        pp_rank: int,
        pp_size: int,
        partition_list_str: Optional[str] = None) -> tuple[int, int]:
    """Return the [start, end) hidden-layer range owned by prefill ``pp_rank``.

    Args:
        num_hidden_layers: Total number of hidden layers in the model.
        pp_rank: Pipeline-parallel rank of the prefill stage to query.
        pp_size: Pipeline-parallel world size of the prefill node.
        partition_list_str: Optional comma-separated per-stage layer counts
            (e.g. ``"33,28"``), mirroring the prefill node's
            ``VLLM_PP_LAYER_PARTITION`` setting. When ``None``, vLLM's
            default layer-split algorithm is used.

    Returns:
        Tuple ``(start_layer, end_layer)`` for the given rank.

    Raises:
        ValueError: If the partition string is malformed, its length does
            not equal ``pp_size``, any entry is not positive, or the entries
            do not sum to ``num_hidden_layers``.
    """
    if partition_list_str is None:
        # No explicit partition configured: fall back to vLLM's default
        # split so decode derives the same ranges prefill computed.
        return get_pp_indices(num_hidden_layers, pp_rank, pp_size)
    try:
        partitions = [int(layer) for layer in partition_list_str.split(",")]
    except ValueError as err:
        raise ValueError("Invalid partition string: {}".format(
            partition_list_str)) from err
    if len(partitions) != pp_size:
        raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
    # Non-positive entries could still satisfy the sum check below while
    # yielding empty or inverted layer ranges, so reject them explicitly.
    if any(p <= 0 for p in partitions):
        raise ValueError(
            f"All partition entries must be positive, got {partitions}.")
    if sum(partitions) != num_hidden_layers:
        raise ValueError(
            f"{sum(partitions)=} does not match {num_hidden_layers=}.")
    start_layer = sum(partitions[:pp_rank])
    end_layer = start_layer + partitions[pp_rank]
    return (start_layer, end_layer)
|
||||
|
||||
Reference in New Issue
Block a user