This reverts commit332b547728. This break deepseek3.2 in PD case. - vLLM version: v0.12.0 - vLLM main:ad32e3e19c
This commit is contained in:
@@ -242,8 +242,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
kv_caches=self.kv_caches)
|
||||
|
||||
def test_add_request(self):
|
||||
test_req = {
|
||||
@@ -296,8 +295,7 @@ class TestSocketManagement(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.remote_sockets = defaultdict(deque)
|
||||
self.thread.remote_poller = MagicMock()
|
||||
|
||||
@@ -354,8 +352,7 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.request_queue = self.mock_queue
|
||||
self.test_req = {
|
||||
"request_id": "req1",
|
||||
@@ -437,8 +434,7 @@ class TestMetadataHandling(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
kv_caches=self.kv_caches)
|
||||
self.test_metadata = MooncakeAgentMetadata(
|
||||
engine_id="remote_engine",
|
||||
te_rpc_port=9090,
|
||||
@@ -502,8 +498,7 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches,
|
||||
prefill_pp_layer_partition=None)
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.request_queue = queue.Queue()
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_handle_request')
|
||||
@@ -540,7 +535,6 @@ class MockVllmConfig:
|
||||
self.parallel_config = MagicMock()
|
||||
self.cache_config = MagicMock()
|
||||
self.kv_transfer_config = MagicMock()
|
||||
self.speculative_config = MagicMock()
|
||||
self.model_config.use_mla = True
|
||||
self.parallel_config.tensor_parallel_size = 2
|
||||
self.parallel_config.data_parallel_rank = 0
|
||||
|
||||
@@ -276,19 +276,12 @@ class KVCacheSendingThread(threading.Thread):
|
||||
|
||||
class KVCacheRecvingThread(threading.Thread):
|
||||
|
||||
def __init__(self,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
_prefill_pp_size: int,
|
||||
engine: TransferEngine,
|
||||
local_engine_id: str,
|
||||
def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
|
||||
engine: TransferEngine, local_engine_id: str,
|
||||
local_handshake_port: int,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
block_len: list[int],
|
||||
ready_event: threading.Event,
|
||||
vllm_config: VllmConfig,
|
||||
kv_caches: dict[str, Any],
|
||||
prefill_pp_layer_partition: Optional[str] = None):
|
||||
local_kv_caches_base_addr: list[int], block_len: list[int],
|
||||
ready_event: threading.Event, vllm_config: VllmConfig,
|
||||
kv_caches: dict[str, Any]):
|
||||
super().__init__(daemon=True, name="KVCacheRecvingThread")
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
@@ -327,14 +320,6 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = self.vllm_config.model_config
|
||||
self.block_size = self.vllm_config.cache_config.block_size
|
||||
self.num_layers = self.model_config.hf_config.num_hidden_layers
|
||||
self.pp_layer_indices = {
|
||||
rank:
|
||||
get_prefill_pp_indices(self.num_layers, rank,
|
||||
self._prefill_pp_size,
|
||||
prefill_pp_layer_partition)
|
||||
for rank in range(self._prefill_pp_size)
|
||||
}
|
||||
if self.use_mla:
|
||||
self.k_head_dim = self.model_config.hf_config.kv_lora_rank
|
||||
self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
|
||||
@@ -455,14 +440,9 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
|
||||
remote_kv_caches_base_addrs = \
|
||||
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
|
||||
first_layer_index, end_layer_index = self.pp_layer_indices[
|
||||
prefill_pp_rank]
|
||||
# support MTP layer kv transfer
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens
|
||||
num_speculative_tokens = 0 if num_speculative_tokens is None else num_speculative_tokens
|
||||
if prefill_pp_rank == self._prefill_pp_size - 1:
|
||||
end_layer_index = end_layer_index + num_speculative_tokens
|
||||
num_layers = self.model_config.hf_config.num_hidden_layers
|
||||
first_layer_index, end_layer_index = get_pp_indices(
|
||||
num_layers, prefill_pp_rank, self._prefill_pp_size)
|
||||
num_cache_per_layer = len(list(
|
||||
self.kv_caches.values())[0]) # Number of KV caches per layer
|
||||
local_kv_caches_base_addrs = \
|
||||
@@ -1045,8 +1025,6 @@ class MooncakeConnectorWorker:
|
||||
# get prefill pp size from extra config
|
||||
self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
|
||||
assert self._decode_pp_size == 1, "decode pp size must be 1"
|
||||
self._prefill_pp_layer_partition = prefill_parallel_config.get(
|
||||
"pp_layer_partition", None)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data."""
|
||||
@@ -1153,8 +1131,7 @@ class MooncakeConnectorWorker:
|
||||
self.kv_recv_thread = KVCacheRecvingThread(
|
||||
self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
|
||||
self.engine_id, self.handshake_port, kv_caches_base_addr,
|
||||
self.block_len, ready_event, self.vllm_config, self.kv_caches,
|
||||
self._prefill_pp_layer_partition)
|
||||
self.block_len, ready_event, self.vllm_config, self.kv_caches)
|
||||
self.kv_recv_thread.start()
|
||||
|
||||
start_wait_time = time.time()
|
||||
@@ -1494,31 +1471,3 @@ def ensure_zmq_recv(
|
||||
raise RuntimeError(
|
||||
f"Failed to receive data after {max_retries} "
|
||||
f"retries: {e}")
|
||||
|
||||
|
||||
# decode node should know pp_partition_layer in prefill node,
|
||||
# it is configured in kv_transfer_config by partition_list_str,
|
||||
# default using vllm layer split algorithm.
|
||||
def get_prefill_pp_indices(
|
||||
num_hidden_layers: int,
|
||||
pp_rank: int,
|
||||
pp_size: int,
|
||||
partition_list_str: Optional[str] = None) -> tuple[int, int]:
|
||||
if partition_list_str is None:
|
||||
return get_pp_indices(num_hidden_layers, pp_rank, pp_size)
|
||||
else:
|
||||
try:
|
||||
partitions = [
|
||||
int(layer) for layer in partition_list_str.split(",")
|
||||
]
|
||||
except ValueError as err:
|
||||
raise ValueError("Invalid partition string: {}".format(
|
||||
partition_list_str)) from err
|
||||
if len(partitions) != pp_size:
|
||||
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
|
||||
if sum(partitions) != num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
||||
start_layer = sum(partitions[:pp_rank])
|
||||
end_layer = start_layer + partitions[pp_rank]
|
||||
return (start_layer, end_layer)
|
||||
|
||||
Reference in New Issue
Block a user