[P/D][PCP] Mooncake layerwise connector: support PCP (#6627)
### What this PR does / why we need it?
Add explicit support for Prefill Context Parallelism (PCP) and Decode
Context Parallelism (DCP) to the Mooncake layerwise KV cache transfer
mechanism, allowing more granular control and awareness of parallel
configurations during data transfer.
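For context, here is a minimal sketch of the `kv_transfer_params` a request carries when pulling KV blocks from a remote prefill instance, based on the `create_request()` test helper touched in this diff; the values are illustrative:

```python
# Sketch based on the create_request() test helper in this PR; values are
# illustrative. remote_pcp_size / remote_dcp_size are the fields this PR
# adds so the connector knows the remote side's context-parallel layout.
kv_transfer_params = dict(
    do_remote_prefill=True,
    do_remote_decode=False,
    remote_engine_id="my-engine-id",
    remote_block_ids=[0, 1, 2, 3],
    remote_host="my-host",
    remote_port=1234,
    remote_tp_size=1,
    remote_pcp_size=1,  # new: remote prefill context parallel size
    remote_dcp_size=1,  # new: remote decode context parallel size
)
```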
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By CI.
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
---------
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
```diff
@@ -61,8 +61,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
                 6000],  # 2 * total_layers
             use_mla=True,
             block_len=[1024, 2048],
             decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
-            k_buffer=self.fake_k_buffer,
-            v_buffer=self.fake_v_buffer,
             resharding_stream=fake_resharding_stream,
```
```diff
@@ -70,6 +68,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):

         self.req_meta_base = ReqMeta(
             local_block_ids=[5, 8],
+            remote_tp_size = 8,
+            remote_pcp_size = 1,
+            remote_dcp_size = 1,
             token_ids=[1, 2, 3],
             remote_block_ids=[10, 20],
             remote_engine_id="remote_engine",
```
```diff
@@ -112,8 +113,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
             kv_cache_base_addr=[1111, 2222, 3333, 4444],
             use_mla=False,
             block_len=[64],
             decode_tp_size=1,
             first_kv_cache=self.first_kv_cache,
-            k_buffer=self.fake_k_buffer,
-            v_buffer=self.fake_v_buffer,
             resharding_stream=fake_resharding_stream,
```
```diff
@@ -242,12 +241,12 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
             th.task_tracker["reqX"] = 0
             th.request_map["reqX"] = "reqX"

-        th.update_task("reqX")
+        th.update_task("reqX", 2)
         with th.lock:
             self.assertIn("reqX", th.task_tracker)
             self.assertNotIn("reqX", th.done_requests)

-        th.update_task("reqX")
+        th.update_task("reqX", 2)
         with th.lock:
             self.assertNotIn("reqX", th.task_tracker)
             self.assertIn("reqX", th.done_requests)
```
```diff
@@ -284,7 +283,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
         dec_inst = MagicMock()
         dec_inst.decode.side_effect = [
             (GET_META_MSG, ),
-            (DONE_SENDING_MSG, "reqA"),
+            (DONE_SENDING_MSG, "reqA", 1),
             (b"weird_msg", ),
         ]
         mock_Decoder.return_value = dec_inst
```
```diff
@@ -339,21 +338,11 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
         finished = th.get_and_clear_finished_requests()
         self.assertIn("reqA", finished)

-    @patch(
-        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
-    )
-    @patch(
-        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip",
-        return_value="127.0.0.1")
-    @patch(
-        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder"
-    )
-    @patch(
-        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder"
-    )
-    @patch(
-        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx"
-    )
+    @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger")
+    @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", return_value="127.0.0.1")
+    @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder")
+    @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder")
+    @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx")
     def test_run_loop_pd_head_ratio_gt1_requires_multiple_done(
             self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip,
             _mock_logger):
```
```diff
@@ -364,8 +353,8 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):

         dec_inst = MagicMock()
         dec_inst.decode.side_effect = [
-            (DONE_SENDING_MSG, "reqB"),
-            (DONE_SENDING_MSG, "reqB"),
+            (DONE_SENDING_MSG, "reqB", 2),
+            (DONE_SENDING_MSG, "reqB", 2),
         ]
         mock_Decoder.return_value = dec_inst
```
```diff
@@ -373,25 +362,26 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
         sock.recv_multipart.side_effect = [
             [b"ID", b"PAY1"],
             [b"ID", b"PAY2"],
             SystemExit,
             SystemExit,  # exit the loop
         ]
         cm = MagicMock()
         cm.__enter__.return_value = sock
         mock_zmq_ctx.return_value = cm

-        th = KVCacheRecvingLayerThread(tp_rank=0,
-                                       side_channel_port=5555,
-                                       tp_size=2,
-                                       pd_head_ratio=2,
-                                       local_engine_id="engineY",
-                                       metadata=self.meta,
-                                       ready_event=self.ready_event)
+        th = KVCacheRecvingLayerThread(
+            tp_rank=0,
+            side_channel_port=5555,
+            tp_size=2,
+            pd_head_ratio=2,
+            local_engine_id="engineY",
+            metadata=self.meta,
+            ready_event=self.ready_event
+        )
         with th.lock:
             th.task_tracker["reqB"] = 0
             th.request_map["reqB"] = "reqB"
         with self.assertRaises(SystemExit):
             th.run()

         finished = th.get_and_clear_finished_requests()
         self.assertIn("reqB", finished)
```
```diff
@@ -441,6 +431,7 @@ class MockRequest:
         self.kv_transfer_params = kv_transfer_params or {}
         self.status = status or "running"
         self.output_token_ids = [101, 102]
+        self.num_computed_tokens = 0

         self.all_token_ids = list(self.prompt_token_ids)
```
```diff
@@ -565,7 +556,8 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
         req = MockRequest("req_u1",
                           prompt_token_ids=list(range(24)),
                           kv_transfer_params={"do_remote_prefill": True})
-        blocks = _MockBlocks(unhashed=[4, 5, 6])
+        req.num_computed_tokens = 0
+        blocks = _MockBlocks(unhashed=[4, 5, 6], block_ids_tuple=([4, 5, 6], ))

         self.scheduler.update_state_after_alloc(req,
                                                 blocks,
```
```diff
@@ -592,7 +584,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
         info = self.scheduler._reqs_need_send_layerwise["req_u2"]
         self.assertEqual(info.local_block_ids, [7, 8, 9])
         self.assertIs(info.request, req)
-        self.assertEqual(info.remote_block_ids, [])

     def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
         self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
```
```diff
@@ -663,12 +654,13 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
         send_req_info.update_computed_tokens = MagicMock()
         send_req_info.update_transferred_tokens = MagicMock()
         send_req_info.unpack = MagicMock(
-            return_value=(send_req_info.local_block_ids,
-                          send_req_info.remote_block_ids,
-                          send_req_info.remote_cache_tokens,
-                          send_req_info.local_transferred_tokens,
-                          send_req_info.local_computed_tokens,
-                          send_req_info.request))
+            return_value=(
+                send_req_info.local_block_ids,
+                send_req_info.local_transferred_tokens,
+                send_req_info.local_computed_tokens,
+                send_req_info.request
+            )
+        )

         self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
         out = _MockSchedulerOutput(
```
```diff
@@ -920,6 +912,11 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
         self.vllm_config = MockVllmConfig()
         self.engine_id = "test_engine"
         self.kv_caches = {"layer1": (MagicMock(), MagicMock())}
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.parallel_config.prefill_context_parallel_size = 1
+        self.vllm_config.parallel_config.decode_context_parallel_size = 1
+        self.vllm_config.parallel_config.data_parallel_rank = 0
+        self.vllm_config.kv_transfer_config.kv_port = 1234

     def tearDown(self):
         for p in self.patches:
```
```diff
@@ -956,4 +953,4 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
                                 self.engine_id)
         worker.register_kv_caches(mla_caches)
         self.assertTrue(worker.use_mla)
         self.assertEqual(len(worker.block_len), 2)
```
```diff
@@ -8,14 +8,11 @@ from typing import Any, Optional

 import torch
 from vllm import SamplingParams
-from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
-                         ModelConfig, SchedulerConfig, VllmConfig)
+from vllm.config import CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.utils.hashing import sha256
-from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
-                                         init_none_hash)
+from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec)
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
```
```diff
@@ -37,14 +34,10 @@ def assert_scheduler_empty(scheduler: Scheduler):
     assert len(scheduler.encoder_cache_manager.cached) == 0

     # KVCache Manager.
-    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
-               req_to_blocks) == 0
-    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
-               num_cached_block) == 0
-    num_free_blocks = (
-        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
-    assert num_free_blocks == (
-        scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
+    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks) == 0
+    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block) == 0
+    num_free_blocks = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
+    assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)

     # NOTE(rob): just the ref count on blocks will be 0. The hash
     # value, etc will remain since we lazily evict for prefix cache.
```
```diff
@@ -63,8 +56,7 @@ def create_vllm_config(
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_num_batched_tokens,
     )
-    fake_weight_path = os.path.join(os.path.dirname(__file__), "..",
-                                    "fake_weight")
+    fake_weight_path = os.path.join(os.path.dirname(__file__), "..", "fake_weight")
     model_config = ModelConfig(
         model=fake_weight_path,
         skip_tokenizer_init=True,
```
```diff
@@ -77,14 +69,14 @@ def create_vllm_config(
         cache_dtype="auto",
         enable_prefix_caching=True,
     )
-    kv_transfer_config = KVTransferConfig(
-        kv_connector="MooncakeConnectorV1",
-        kv_role="kv_both")
-    return VllmConfig(scheduler_config=scheduler_config,
-                      model_config=model_config,
-                      cache_config=cache_config,
-                      kv_transfer_config=kv_transfer_config,
-                      device_config=DeviceConfig("cpu"))
+    kv_transfer_config = KVTransferConfig(kv_connector="MooncakeConnectorV1", kv_role="kv_both")
+    return VllmConfig(
+        scheduler_config=scheduler_config,
+        model_config=model_config,
+        cache_config=cache_config,
+        kv_transfer_config=kv_transfer_config,
+        device_config=DeviceConfig("cpu"),
+    )


 def create_scheduler(
```
```diff
@@ -96,11 +88,7 @@ def create_scheduler(
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
         kv_cache_tensors=[],
-        kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(block_size, 1, 1, torch.float16,
-                                               False, False))
-        ],
+        kv_cache_groups=[KVCacheGroupSpec(["layer"], FullAttentionSpec(block_size, 1, 1, torch.float16, False, False))],
     )
     vllm_config.cache_config.num_gpu_blocks = num_blocks
```
```diff
@@ -138,19 +126,19 @@ def create_request(

     if do_remote_decode:
         assert not do_remote_prefill
-        kv_transfer_params = dict(do_remote_prefill=False,
-                                  do_remote_decode=True)
+        kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
     elif do_remote_prefill:
-        kv_transfer_params = dict(do_remote_prefill=True,
-                                  do_remote_decode=False,
-                                  remote_engine_id="my-engine-id",
-                                  remote_block_ids=list(
-                                      range(num_remote_blocks)),
-                                  remote_host="my-host",
-                                  remote_port=1234,
-                                  remote_tp_size=1,
-                                  remote_pcp_size=1,
-                                  remote_dcp_size=1)
+        kv_transfer_params = dict(
+            do_remote_prefill=True,
+            do_remote_decode=False,
+            remote_engine_id="my-engine-id",
+            remote_block_ids=list(range(num_remote_blocks)),
+            remote_host="my-host",
+            remote_port=1234,
+            remote_tp_size=1,
+            remote_pcp_size=1,
+            remote_dcp_size=1,
+        )

     max_tokens = 1 if do_remote_decode else max_tokens
     sampling_params = SamplingParams(max_tokens=max_tokens)
```
```diff
@@ -190,10 +178,9 @@ def create_model_runner_output(

     # Make output data structure.
     extra_args = {}
-    from vllm.v1.worker.kv_connector_model_runner_mixin import \
-        KVConnectorOutput  # type: ignore # noqa
-    kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
-                                            finished_recving=finished_recving)
+    from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput  # type: ignore # noqa
+
+    kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving)
     extra_args = {"kv_connector_output": kv_connector_output}

     model_runner_output = ModelRunnerOutput(
```