[Refactor]refactor p2p connector (#6551)

### What this PR does / why we need it?
Redundant code is removed, and repeated logic is combined through the
p2p connector refactor, making the code easy to extend.

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?
P节点:
```
vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \
  --host 0.0.0.0 \
  --port 8002 \
  --data-parallel-size 2 \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --seed 1024 \
  --served-model-name model \
  --max-model-len 8192 \
  --max-num-batched-tokens 8192 \
  --max-num-seqs 16 \
  --enforce-eager \
  --trust-remote-code \
  --gpu-memory-utilization 0.92 \
  --quantization ascend \
  --async-scheduling \
  --additional-config '{"ascend_scheduler_config":{"enabled":true}}' \
  --kv-transfer-config \
  '{
        "kv_connector": "MultiConnector",
        "kv_role": "kv_producer",
        "kv_connector_extra_config": {
                "use_layerwise": false,
                "connectors": [
                        {
                                "kv_connector": "MooncakeConnectorV1",
                                "kv_role": "kv_producer",
                                "kv_port": "30000",
                                "kv_connector_extra_config": {
                                        "use_ascend_direct": true,
                                        "prefill": {
                                                "dp_size": 2,
                                                "tp_size": 8
                                        },
                                        "decode": {
                                                "dp_size": 4,
                                                "tp_size": 4
                                        }
                                }
                        },
			{
                                "kv_connector": "AscendStoreConnector",
                                "kv_role": "kv_producer",
                                "kv_connector_extra_config": {
                                        "backend": "mooncake",
                                        "mooncake_rpc_port":"0"
                                }
                        }

                ]
        }
  }'
```

D节点:
```
vllm serve /mnt/share/DeepSeek-V3.2-Exp-W8A8 \
  --host 0.0.0.0 \
  --port 8003 \
  --data-parallel-size 4 \
  --tensor-parallel-size 4 \
  --enable-expert-parallel \
  --seed 1024 \
  --served-model-name model \
  --max-model-len 8192 \
  --max-num-batched-tokens 8192 \
  --max-num-seqs 16 \
  --enforce-eager \
  --trust-remote-code \
  --gpu-memory-utilization 0.92  \
  --quantization ascend \
  --async-scheduling \
  --additional-config '{"ascend_scheduler_config":{"enabled":true}}' \
  --kv-transfer-config \
  '{
        "kv_connector": "MultiConnector",
        "kv_role": "kv_consumer",
        "kv_connector_extra_config": {
                "use_layerwise": false,
                "connectors": [
                        {
                                "kv_connector": "MooncakeConnectorV1",
                                "kv_role": "kv_consumer",
                                "kv_port": "30100",
                                "kv_connector_extra_config": {
                                        "use_ascend_direct": true,
                                        "prefill": {
                                                "dp_size": 2,
                                                "tp_size": 8
                                        },
                                        "decode": {
                                                "dp_size": 4,
                                                "tp_size": 4
                                        }
                                }
                        },{

                                "kv_connector": "AscendStoreConnector",
                                "kv_role": "kv_consumer",
                                "kv_connector_extra_config": {
                                        "backend": "mooncake",
                                        "mooncake_rpc_port":"1"
                                }
                        }

                ]
        }
  }'
```
- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: lty <linhebiwen@gmail.com>
This commit is contained in:
lty
2026-02-07 09:27:15 +08:00
committed by GitHub
parent 4f33e25046
commit c3db1aca2f
2 changed files with 42 additions and 134 deletions

View File

@@ -333,7 +333,6 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sparse = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
self.executor = ThreadPoolExecutor(max_workers=32)
@@ -529,15 +528,11 @@ class KVCacheRecvingThread(threading.Thread):
req_start_time = time.perf_counter()
src_list, dst_list, length_list = [], [], []
block_length = len(self.block_len)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)
):
if self.use_mla:
block_len = self.block_len[k % 2]
elif self.use_sparse:
block_len = self.block_len[k % 3]
else:
block_len = self.block_len[0]
block_len = self.block_len[k % block_length]
inner_block_len = block_len // tp_num_need_pulls
for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids):
src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len
@@ -1196,51 +1191,25 @@ class MooncakeConnectorWorker:
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
)
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
self.num_blocks = first_kv_cache.shape[0]
logger.info("num_blocks: %s", self.num_blocks)
self.block_len = []
if self.use_mla or self.use_sparse:
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
elif self.use_sparse:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
block_shape_k,
)
for i in range(len(first_kv_cache_tuple)):
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
logger.info("block_shape: %s", block_shape)
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
else:
# eager:[num_block, block_size, num_head, hidden_dim]
# torchair:[num_block, block_size, num_head*hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = (
len(first_kv_cache.shape) - 1
) # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
logger.info("block_shape: %s", block_shape)
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]
logger.info(
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla,
@@ -1252,30 +1221,15 @@ class MooncakeConnectorWorker:
kv_caches_base_addr = []
ptrs = []
lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % length]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
global_te.register_buffer(ptrs, lengths)
# After KV Caches registered, start the sending or receiving thread.
metadata = MooncakeAgentMetadata(

View File

@@ -233,15 +233,11 @@ class KVCacheSendingLayerThread(threading.Thread):
remote_block_ids, local_block_ids
)
block_length = len(self.block_len)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
):
if self.use_mla:
block_len = self.block_len[k % 2]
elif self.use_sparse:
block_len = self.block_len[k % 3]
else:
block_len = self.block_len[0]
block_len = self.block_len[k % block_length]
for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids
):
@@ -925,48 +921,21 @@ class MooncakeLayerwiseConnectorWorker:
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
)
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
elif self.use_sparse:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
block_shape_k,
)
self.num_blocks = first_kv_cache.shape[0]
logger.info("num_blocks: %s", self.num_blocks)
block_rank = 3
self.block_len = []
if self.use_mla or self.use_sparse:
for i in range(len(first_kv_cache_tuple)):
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
logger.info("block_shape: %s", block_shape)
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
logger.info("block_shape: %s", block_shape)
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]
logger.info(
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
@@ -979,30 +948,15 @@ class MooncakeLayerwiseConnectorWorker:
kv_caches_base_addr = []
ptrs = []
lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % length]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
global_te.register_buffer(ptrs, lengths)
self.kv_caches_base_addr = kv_caches_base_addr