[Refactor] refactor p2p connector (#6551)
### What this PR does / why we need it?
This PR refactors the p2p connector: redundant code is removed and repeated logic is consolidated, making the code easier to extend.
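At its core, the branch-per-layout selection of the per-layer block length (standard / MLA / sparse) is collapsed into a single modular lookup over `block_len`. A minimal sketch of the pattern with simplified names (illustrative only, not the actual connector code; see the diff hunks below):

```
def pick_block_len(block_len, k):
    # Before: if use_mla: block_len[k % 2]; elif use_sparse: block_len[k % 3]; else: block_len[0]
    # After: one lookup that covers 1, 2, or 3 entries uniformly.
    return block_len[k % len(block_len)]

# standard layout -> 1 entry, MLA -> 2 entries, sparse -> 3 entries
assert pick_block_len([4096], 5) == 4096
assert pick_block_len([4096, 512], 3) == 512
assert pick_block_len([4096, 512, 256], 4) == 512
```

The same idea is reused for KV-cache registration, where the region length becomes `num_blocks * block_len[i % len(block_len)]`.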
### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?
P node (prefill):
```
vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \
--host 0.0.0.0 \
--port 8002 \
--data-parallel-size 2 \
--tensor-parallel-size 8 \
--enable-expert-parallel \
--seed 1024 \
--served-model-name model \
--max-model-len 8192 \
--max-num-batched-tokens 8192 \
--max-num-seqs 16 \
--enforce-eager \
--trust-remote-code \
--gpu-memory-utilization 0.92 \
--quantization ascend \
--async-scheduling \
--additional-config '{"ascend_scheduler_config":{"enabled":true}}' \
--kv-transfer-config \
'{
"kv_connector": "MultiConnector",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
"use_layerwise": false,
"connectors": [
{
"kv_connector": "MooncakeConnectorV1",
"kv_role": "kv_producer",
"kv_port": "30000",
"kv_connector_extra_config": {
"use_ascend_direct": true,
"prefill": {
"dp_size": 2,
"tp_size": 8
},
"decode": {
"dp_size": 4,
"tp_size": 4
}
}
},
{
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
"backend": "mooncake",
"mooncake_rpc_port":"0"
}
}
]
}
}'
```
D node (decode):
```
vllm serve /mnt/share/DeepSeek-V3.2-Exp-W8A8 \
--host 0.0.0.0 \
--port 8003 \
--data-parallel-size 4 \
--tensor-parallel-size 4 \
--enable-expert-parallel \
--seed 1024 \
--served-model-name model \
--max-model-len 8192 \
--max-num-batched-tokens 8192 \
--max-num-seqs 16 \
--enforce-eager \
--trust-remote-code \
--gpu-memory-utilization 0.92 \
--quantization ascend \
--async-scheduling \
--additional-config '{"ascend_scheduler_config":{"enabled":true}}' \
--kv-transfer-config \
'{
"kv_connector": "MultiConnector",
"kv_role": "kv_consumer",
"kv_connector_extra_config": {
"use_layerwise": false,
"connectors": [
{
"kv_connector": "MooncakeConnectorV1",
"kv_role": "kv_consumer",
"kv_port": "30100",
"kv_connector_extra_config": {
"use_ascend_direct": true,
"prefill": {
"dp_size": 2,
"tp_size": 8
},
"decode": {
"dp_size": 4,
"tp_size": 4
}
}
},
{
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_consumer",
"kv_connector_extra_config": {
"backend": "mooncake",
"mooncake_rpc_port":"1"
}
}
]
}
}'
```
- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0
---------
Signed-off-by: lty <linhebiwen@gmail.com>
```
@@ -333,7 +333,6 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sparse = len(block_len) == 3

self.request_queue: queue.Queue[Any] = queue.Queue()
self.executor = ThreadPoolExecutor(max_workers=32)
```
```
@@ -529,15 +528,11 @@ class KVCacheRecvingThread(threading.Thread):
req_start_time = time.perf_counter()
src_list, dst_list, length_list = [], [], []
block_length = len(self.block_len)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
    zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)
):
    if self.use_mla:
        block_len = self.block_len[k % 2]
    elif self.use_sparse:
        block_len = self.block_len[k % 3]
    else:
        block_len = self.block_len[0]
    block_len = self.block_len[k % block_length]
    inner_block_len = block_len // tp_num_need_pulls
    for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids):
        src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len
```
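For reference, a self-contained sketch of how the per-block source offsets in this loop are formed once `block_len` has been picked by the modular lookup (simplified from the hunk above; `inner_offset` and `tp_num_need_pulls` are treated as plain ints here):

```
def block_src_addr(layer_base_addr: int, block_id: int, block_len: int,
                   inner_offset: int, tp_num_need_pulls: int) -> int:
    # Each block occupies block_len bytes; the pulling rank reads only its
    # slice of the block, offset by inner_offset * inner_block_len.
    inner_block_len = block_len // tp_num_need_pulls
    return layer_base_addr + block_id * block_len + inner_offset * inner_block_len

# e.g. block 3 of a layer whose base is 0x1000, 4 KiB blocks split across 2 pullers
addr = block_src_addr(0x1000, 3, 4096, inner_offset=1, tp_num_need_pulls=2)
assert addr == 0x1000 + 3 * 4096 + 2048
```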
```
@@ -1196,51 +1191,25 @@ class MooncakeConnectorWorker:
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
)
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
    # MLA case.[num_block, block_size, 1, hidden_dim]
    self.num_blocks = first_kv_cache.shape[0]

self.num_blocks = first_kv_cache.shape[0]
logger.info("num_blocks: %s", self.num_blocks)
self.block_len = []
if self.use_mla or self.use_sparse:
    block_rank = 3  # [block_size, latent_dim]
    block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
    block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
    self.block_len = [
        first_kv_cache[0].element_size() * math.prod(block_shape_norm),
        first_kv_cache[1].element_size() * math.prod(block_shape_pe),
    ]
    logger.info(
        "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
        self.num_blocks,
        block_shape_norm,
        block_shape_pe,
    )
elif self.use_sparse:
    self.num_blocks = first_kv_cache.shape[0]
    block_rank = 3  # [block_size, latent_dim]
    block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
    block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
    block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
    self.block_len = [
        first_kv_cache[0].element_size() * math.prod(block_shape_norm),
        first_kv_cache[1].element_size() * math.prod(block_shape_pe),
        first_kv_cache[2].element_size() * math.prod(block_shape_k),
    ]
    logger.info(
        "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
        self.num_blocks,
        block_shape_norm,
        block_shape_pe,
        block_shape_k,
    )
    for i in range(len(first_kv_cache_tuple)):
        block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
        logger.info("block_shape: %s", block_shape)
        self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
else:
    # eager:[num_block, block_size, num_head, hidden_dim]
    # torchair:[num_block, block_size, num_head*hidden_dim]
    self.num_blocks = first_kv_cache.shape[0]
    kv_elem_size = first_kv_cache.element_size()
    block_rank = (
        len(first_kv_cache.shape) - 1
    )  # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
    block_shape = first_kv_cache.shape[-block_rank:]
    self.block_len = [kv_elem_size * math.prod(block_shape)]
    logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
    logger.info("block_shape: %s", block_shape)
    self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]

logger.info(
    "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
    self.use_mla,
```
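The consolidated branch above appears to derive one `block_len` entry per tensor in the layer's KV tuple, instead of spelling out the MLA and sparse cases separately. A hedged sketch of that computation (illustrative; assumes the per-layer caches expose `element_size()` and `shape`, as in the hunk):

```
import math
import torch

def compute_block_lens(kv_cache_tuple, block_rank=3):
    # One entry per tensor: 1 for the standard layout, 2 for MLA (norm + pe),
    # 3 for sparse (norm + pe + k). Each entry is the byte size of one block.
    block_len = []
    for cache in kv_cache_tuple:
        block_shape = cache.shape[-block_rank:]
        block_len.append(cache.element_size() * math.prod(block_shape))
    return block_len

# e.g. an MLA-style pair of caches shaped [num_blocks, block_size, 1, latent_dim]
norm = torch.zeros(10, 128, 1, 512, dtype=torch.float16)
pe = torch.zeros(10, 128, 1, 64, dtype=torch.float16)
assert compute_block_lens((norm, pe)) == [128 * 1 * 512 * 2, 128 * 1 * 64 * 2]
```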
```
@@ -1252,30 +1221,15 @@ class MooncakeConnectorWorker:
kv_caches_base_addr = []
ptrs = []
lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values():
    # Normalize to always be a list of caches
    if self.use_mla:
        for i, cache in enumerate(cache_or_caches, 0):
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len[i % 2]
            kv_caches_base_addr.append(base_addr)
            ptrs.append(base_addr)
            lengths.append(region_len)
    elif self.use_sparse:
        for i, cache in enumerate(cache_or_caches, 0):
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len[i % 3]
            kv_caches_base_addr.append(base_addr)
            ptrs.append(base_addr)
            lengths.append(region_len)
    else:
        cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
        for cache in cache_list:
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len[0]
            kv_caches_base_addr.append(base_addr)
            ptrs.append(base_addr)
            lengths.append(region_len)
    for i, cache in enumerate(cache_or_caches, 0):
        base_addr = cache.data_ptr()
        region_len = self.num_blocks * self.block_len[i % length]
        kv_caches_base_addr.append(base_addr)
        ptrs.append(base_addr)
        lengths.append(region_len)
global_te.register_buffer(ptrs, lengths)
# After KV Caches registered, start the sending or receiving thread.
metadata = MooncakeAgentMetadata(
```
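The registration hunk above follows the same idea: one loop over each layer's caches, with the region size taken as `num_blocks * block_len[i % len(block_len)]`. A simplified sketch of how the pointer/length lists passed to `global_te.register_buffer` could be built (hypothetical helper, not the connector's exact code):

```
def collect_regions(kv_caches, num_blocks, block_len):
    # kv_caches: dict mapping layer name -> tuple/list of cache tensors.
    ptrs, lengths = [], []
    n = len(block_len)
    for caches in kv_caches.values():
        for i, cache in enumerate(caches):
            ptrs.append(cache.data_ptr())
            lengths.append(num_blocks * block_len[i % n])
    return ptrs, lengths

# usage matching the hunk:
#   ptrs, lengths = collect_regions(kv_caches, self.num_blocks, self.block_len)
#   global_te.register_buffer(ptrs, lengths)
```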
```
@@ -233,15 +233,11 @@ class KVCacheSendingLayerThread(threading.Thread):
    remote_block_ids, local_block_ids
)

block_length = len(self.block_len)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
    zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
):
    if self.use_mla:
        block_len = self.block_len[k % 2]
    elif self.use_sparse:
        block_len = self.block_len[k % 3]
    else:
        block_len = self.block_len[0]
    block_len = self.block_len[k % block_length]
    for group_remote_block_id, group_local_block_id in zip(
        grouped_remote_block_ids, grouped_local_block_ids
    ):
```
```
@@ -925,48 +921,21 @@ class MooncakeLayerwiseConnectorWorker:
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
)
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
    # MLA case.[num_block, block_size, 1, hidden_dim]
    self.num_blocks = first_kv_cache.shape[0]
    block_rank = 3  # [block_size, latent_dim]
    block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
    block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
    self.block_len = [
        first_kv_cache[0].element_size() * math.prod(block_shape_norm),
        first_kv_cache[1].element_size() * math.prod(block_shape_pe),
    ]
    logger.info(
        "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
        self.num_blocks,
        block_shape_norm,
        block_shape_pe,
    )
elif self.use_sparse:
    self.num_blocks = first_kv_cache.shape[0]
    block_rank = 3  # [block_size, latent_dim]
    block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
    block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
    block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
    self.block_len = [
        first_kv_cache[0].element_size() * math.prod(block_shape_norm),
        first_kv_cache[1].element_size() * math.prod(block_shape_pe),
        first_kv_cache[2].element_size() * math.prod(block_shape_k),
    ]
    logger.info(
        "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
        self.num_blocks,
        block_shape_norm,
        block_shape_pe,
        block_shape_k,
    )

self.num_blocks = first_kv_cache.shape[0]
logger.info("num_blocks: %s", self.num_blocks)
block_rank = 3
self.block_len = []
if self.use_mla or self.use_sparse:
    for i in range(len(first_kv_cache_tuple)):
        block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
        logger.info("block_shape: %s", block_shape)
        self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
else:
    # [num_block, block_size, num_head, hidden_dim]
    self.num_blocks = first_kv_cache.shape[0]
    kv_elem_size = first_kv_cache.element_size()
    block_rank = 3  # [block_size, kv_heads, head_dim]
    block_shape = first_kv_cache.shape[-block_rank:]
    self.block_len = [kv_elem_size * math.prod(block_shape)]
    logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
    logger.info("block_shape: %s", block_shape)
    self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]

logger.info(
    "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
```
```
@@ -979,30 +948,15 @@ class MooncakeLayerwiseConnectorWorker:
kv_caches_base_addr = []
ptrs = []
lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values():
    # Normalize to always be a list of caches
    if self.use_mla:
        for i, cache in enumerate(cache_or_caches, 0):
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len[i % 2]
            kv_caches_base_addr.append(base_addr)
            ptrs.append(base_addr)
            lengths.append(region_len)
    elif self.use_sparse:
        for i, cache in enumerate(cache_or_caches, 0):
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len[i % 3]
            kv_caches_base_addr.append(base_addr)
            ptrs.append(base_addr)
            lengths.append(region_len)
    else:
        cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
        for cache in cache_list:
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len[0]
            kv_caches_base_addr.append(base_addr)
            ptrs.append(base_addr)
            lengths.append(region_len)
    for i, cache in enumerate(cache_or_caches, 0):
        base_addr = cache.data_ptr()
        region_len = self.num_blocks * self.block_len[i % length]
        kv_caches_base_addr.append(base_addr)
        ptrs.append(base_addr)
        lengths.append(region_len)
global_te.register_buffer(ptrs, lengths)
self.kv_caches_base_addr = kv_caches_base_addr
```