[P/D][PCP] Support PCP in the Mooncake layerwise connector (#6627)

### What this PR does / why we need it?
Adds explicit support for Prefill Context Parallelism (PCP) and Decode
Context Parallelism (DCP) to the Mooncake layerwise KV cache transfer
mechanism, so the connector is aware of the remote parallel configuration
and can exercise more granular control during data transfer.
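For orientation, a minimal sketch of the per-request metadata shape this implies is shown below. The field names (`remote_tp_size`, `remote_pcp_size`, `remote_dcp_size`) are taken from the tests in this diff; the dataclass itself is illustrative, not the connector's exact definition.

```python
from dataclasses import dataclass, field


@dataclass
class ReqMetaSketch:
    """Illustrative only: request metadata now carries the remote parallel layout."""
    local_block_ids: list[int]
    remote_block_ids: list[int]
    remote_engine_id: str
    token_ids: list[int] = field(default_factory=list)
    remote_tp_size: int = 1   # remote tensor parallel world size
    remote_pcp_size: int = 1  # remote prefill context parallel size (new in this PR)
    remote_dcp_size: int = 1  # remote decode context parallel size (new in this PR)
```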

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By CI.

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
7 changed files with 551 additions and 223 deletions


@@ -232,15 +232,17 @@ class TestAscendAttentionCPImpl(TestBase):
self.assertEqual(value.shape[1], num_heads)
self.assertEqual(value.shape[2], head_size)
@patch('torch_npu.Event', create=True)
@patch('torch_npu._npu_reshape_and_cache')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_reshape_and_cache(self, mock_npu_reshape_and_cache):
def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_event_class):
num_tokens = 4
block_num = 100
block_size = 128
num_heads = 1
head_size = 128
self.impl.head_size = head_size
self.impl.is_kv_producer = False
kv_cache = (torch.randn(block_num, block_size, num_heads, head_size),
torch.randn(block_num, block_size, num_heads, head_size))


@@ -61,8 +61,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
6000], # 2 * total_layers
use_mla=True,
block_len=[1024, 2048],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
@@ -70,6 +68,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.req_meta_base = ReqMeta(
local_block_ids=[5, 8],
remote_tp_size=8,
remote_pcp_size=1,
remote_dcp_size=1,
token_ids=[1, 2, 3],
remote_block_ids=[10, 20],
remote_engine_id="remote_engine",
@@ -112,8 +113,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
kv_cache_base_addr=[1111, 2222, 3333, 4444],
use_mla=False,
block_len=[64],
decode_tp_size=1,
first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
@@ -242,12 +241,12 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
th.task_tracker["reqX"] = 0
th.request_map["reqX"] = "reqX"
th.update_task("reqX")
th.update_task("reqX", 2)
with th.lock:
self.assertIn("reqX", th.task_tracker)
self.assertNotIn("reqX", th.done_requests)
th.update_task("reqX")
th.update_task("reqX", 2)
with th.lock:
self.assertNotIn("reqX", th.task_tracker)
self.assertIn("reqX", th.done_requests)
@@ -284,7 +283,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
dec_inst = MagicMock()
dec_inst.decode.side_effect = [
(GET_META_MSG, ),
(DONE_SENDING_MSG, "reqA"),
(DONE_SENDING_MSG, "reqA", 1),
(b"weird_msg", ),
]
mock_Decoder.return_value = dec_inst
@@ -339,21 +338,11 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
finished = th.get_and_clear_finished_requests()
self.assertIn("reqA", finished)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip",
return_value="127.0.0.1")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx"
)
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", return_value="127.0.0.1")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder")
@patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx")
def test_run_loop_pd_head_ratio_gt1_requires_multiple_done(
self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip,
_mock_logger):
@@ -364,8 +353,8 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
dec_inst = MagicMock()
dec_inst.decode.side_effect = [
(DONE_SENDING_MSG, "reqB"),
(DONE_SENDING_MSG, "reqB"),
(DONE_SENDING_MSG, "reqB", 2),
(DONE_SENDING_MSG, "reqB", 2),
]
mock_Decoder.return_value = dec_inst
@@ -373,25 +362,26 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
sock.recv_multipart.side_effect = [
[b"ID", b"PAY1"],
[b"ID", b"PAY2"],
SystemExit,
SystemExit, # exit the loop
]
cm = MagicMock()
cm.__enter__.return_value = sock
mock_zmq_ctx.return_value = cm
th = KVCacheRecvingLayerThread(tp_rank=0,
side_channel_port=5555,
tp_size=2,
pd_head_ratio=2,
local_engine_id="engineY",
metadata=self.meta,
ready_event=self.ready_event)
th = KVCacheRecvingLayerThread(
tp_rank=0,
side_channel_port=5555,
tp_size=2,
pd_head_ratio=2,
local_engine_id="engineY",
metadata=self.meta,
ready_event=self.ready_event
)
with th.lock:
th.task_tracker["reqB"] = 0
th.request_map["reqB"] = "reqB"
with self.assertRaises(SystemExit):
th.run()
finished = th.get_and_clear_finished_requests()
self.assertIn("reqB", finished)
@@ -441,6 +431,7 @@ class MockRequest:
self.kv_transfer_params = kv_transfer_params or {}
self.status = status or "running"
self.output_token_ids = [101, 102]
self.num_computed_tokens = 0
self.all_token_ids = list(self.prompt_token_ids)
@@ -565,7 +556,8 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
req = MockRequest("req_u1",
prompt_token_ids=list(range(24)),
kv_transfer_params={"do_remote_prefill": True})
blocks = _MockBlocks(unhashed=[4, 5, 6])
req.num_computed_tokens = 0
blocks = _MockBlocks(unhashed=[4, 5, 6], block_ids_tuple=([4, 5, 6], ))
self.scheduler.update_state_after_alloc(req,
blocks,
@@ -592,7 +584,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
info = self.scheduler._reqs_need_send_layerwise["req_u2"]
self.assertEqual(info.local_block_ids, [7, 8, 9])
self.assertIs(info.request, req)
self.assertEqual(info.remote_block_ids, [])
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
@@ -663,12 +654,13 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
send_req_info.update_computed_tokens = MagicMock()
send_req_info.update_transferred_tokens = MagicMock()
send_req_info.unpack = MagicMock(
return_value=(send_req_info.local_block_ids,
send_req_info.remote_block_ids,
send_req_info.remote_cache_tokens,
send_req_info.local_transferred_tokens,
send_req_info.local_computed_tokens,
send_req_info.request))
return_value=(
send_req_info.local_block_ids,
send_req_info.local_transferred_tokens,
send_req_info.local_computed_tokens,
send_req_info.request
)
)
self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
out = _MockSchedulerOutput(
@@ -920,6 +912,11 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
self.vllm_config = MockVllmConfig()
self.engine_id = "test_engine"
self.kv_caches = {"layer1": (MagicMock(), MagicMock())}
self.vllm_config.parallel_config.tensor_parallel_size = 1
self.vllm_config.parallel_config.prefill_context_parallel_size = 1
self.vllm_config.parallel_config.decode_context_parallel_size = 1
self.vllm_config.parallel_config.data_parallel_rank = 0
self.vllm_config.kv_transfer_config.kv_port = 1234
def tearDown(self):
for p in self.patches:
@@ -956,4 +953,4 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
self.engine_id)
worker.register_kv_caches(mla_caches)
self.assertTrue(worker.use_mla)
self.assertEqual(len(worker.block_len), 2)
self.assertEqual(len(worker.block_len), 2)
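The `setUp` additions above list the configuration surface the worker now reads. A minimal sketch of pulling those values (attribute names exactly as mocked above; the helper function itself is illustrative):

```python
def get_parallel_layout(vllm_config):
    # Sketch: collect the local parallel world sizes the worker advertises
    # during the KV transfer handshake.
    pc = vllm_config.parallel_config
    return {
        "tp_size": pc.tensor_parallel_size,
        "pcp_size": pc.prefill_context_parallel_size,  # new in this PR
        "dcp_size": pc.decode_context_parallel_size,   # new in this PR
        "dp_rank": pc.data_parallel_rank,
    }
```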


@@ -8,14 +8,11 @@ from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.config import CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
@@ -37,14 +34,10 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block) == 0
num_free_blocks = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
@@ -63,8 +56,7 @@ def create_vllm_config(
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
)
fake_weight_path = os.path.join(os.path.dirname(__file__), "..",
"fake_weight")
fake_weight_path = os.path.join(os.path.dirname(__file__), "..", "fake_weight")
model_config = ModelConfig(
model=fake_weight_path,
skip_tokenizer_init=True,
@@ -77,14 +69,14 @@ def create_vllm_config(
cache_dtype="auto",
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="MooncakeConnectorV1",
kv_role="kv_both")
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"))
kv_transfer_config = KVTransferConfig(kv_connector="MooncakeConnectorV1", kv_role="kv_both")
return VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"),
)
def create_scheduler(
@@ -96,11 +88,7 @@ def create_scheduler(
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float16,
False, False))
],
kv_cache_groups=[KVCacheGroupSpec(["layer"], FullAttentionSpec(block_size, 1, 1, torch.float16, False, False))],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
@@ -138,19 +126,19 @@ def create_request(
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True)
kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = dict(do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1,
remote_pcp_size=1,
remote_dcp_size=1)
kv_transfer_params = dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1,
remote_pcp_size=1,
remote_dcp_size=1,
)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
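As a hedged usage example (keyword names are taken from the fragments visible in this hunk; any other required arguments of `create_request` are omitted for brevity), a remote-prefill request built by this helper now advertises the remote PCP/DCP sizes:

```python
req = create_request(do_remote_prefill=True, num_remote_blocks=2)
assert req.kv_transfer_params["remote_tp_size"] == 1
assert req.kv_transfer_params["remote_pcp_size"] == 1
assert req.kv_transfer_params["remote_dcp_size"] == 1
```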
@@ -190,10 +178,9 @@ def create_model_runner_output(
# Make output data structure.
extra_args = {}
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
finished_recving=finished_recving)
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # type: ignore # noqa
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving)
extra_args = {"kv_connector_output": kv_connector_output}
model_runner_output = ModelRunnerOutput(