[Refactor]refactor p2p connector (#6551)
### What this PR does / why we need it?
Redundant code is removed, and repeated logic is combined through the
p2p connector refactor, making the code easy to extend.
### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?
Prefill (P) node:
```
vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \
--host 0.0.0.0 \
--port 8002 \
--data-parallel-size 2 \
--tensor-parallel-size 8 \
--enable-expert-parallel \
--seed 1024 \
--served-model-name model \
--max-model-len 8192 \
--max-num-batched-tokens 8192 \
--max-num-seqs 16 \
--enforce-eager \
--trust-remote-code \
--gpu-memory-utilization 0.92 \
--quantization ascend \
--async-scheduling \
--additional-config '{"ascend_scheduler_config":{"enabled":true}}' \
--kv-transfer-config \
'{
"kv_connector": "MultiConnector",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
"use_layerwise": false,
"connectors": [
{
"kv_connector": "MooncakeConnectorV1",
"kv_role": "kv_producer",
"kv_port": "30000",
"kv_connector_extra_config": {
"use_ascend_direct": true,
"prefill": {
"dp_size": 2,
"tp_size": 8
},
"decode": {
"dp_size": 4,
"tp_size": 4
}
}
},
{
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
"backend": "mooncake",
"mooncake_rpc_port":"0"
}
}
]
}
}'
```
Decode (D) node:
```
vllm serve /mnt/share/DeepSeek-V3.2-Exp-W8A8 \
--host 0.0.0.0 \
--port 8003 \
--data-parallel-size 4 \
--tensor-parallel-size 4 \
--enable-expert-parallel \
--seed 1024 \
--served-model-name model \
--max-model-len 8192 \
--max-num-batched-tokens 8192 \
--max-num-seqs 16 \
--enforce-eager \
--trust-remote-code \
--gpu-memory-utilization 0.92 \
--quantization ascend \
--async-scheduling \
--additional-config '{"ascend_scheduler_config":{"enabled":true}}' \
--kv-transfer-config \
'{
"kv_connector": "MultiConnector",
"kv_role": "kv_consumer",
"kv_connector_extra_config": {
"use_layerwise": false,
"connectors": [
{
"kv_connector": "MooncakeConnectorV1",
"kv_role": "kv_consumer",
"kv_port": "30100",
"kv_connector_extra_config": {
"use_ascend_direct": true,
"prefill": {
"dp_size": 2,
"tp_size": 8
},
"decode": {
"dp_size": 4,
"tp_size": 4
}
}
},
{
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_consumer",
"kv_connector_extra_config": {
"backend": "mooncake",
"mooncake_rpc_port":"1"
}
}
]
}
}'
```
- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0
---------
Signed-off-by: lty <linhebiwen@gmail.com>
This commit is contained in:
@@ -333,7 +333,6 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
self.block_len = block_len
|
self.block_len = block_len
|
||||||
# TODO(jianzs): find a better way to detect MLA.
|
# TODO(jianzs): find a better way to detect MLA.
|
||||||
self.use_mla = len(block_len) == 2
|
self.use_mla = len(block_len) == 2
|
||||||
self.use_sparse = len(block_len) == 3
|
|
||||||
|
|
||||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||||
@@ -529,15 +528,11 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
|
|
||||||
req_start_time = time.perf_counter()
|
req_start_time = time.perf_counter()
|
||||||
src_list, dst_list, length_list = [], [], []
|
src_list, dst_list, length_list = [], [], []
|
||||||
|
block_length = len(self.block_len)
|
||||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||||
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)
|
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)
|
||||||
):
|
):
|
||||||
if self.use_mla:
|
block_len = self.block_len[k % block_length]
|
||||||
block_len = self.block_len[k % 2]
|
|
||||||
elif self.use_sparse:
|
|
||||||
block_len = self.block_len[k % 3]
|
|
||||||
else:
|
|
||||||
block_len = self.block_len[0]
|
|
||||||
inner_block_len = block_len // tp_num_need_pulls
|
inner_block_len = block_len // tp_num_need_pulls
|
||||||
for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids):
|
for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids):
|
||||||
src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len
|
src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len
|
||||||
@@ -1196,51 +1191,25 @@ class MooncakeConnectorWorker:
|
|||||||
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
|
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
|
||||||
)
|
)
|
||||||
self.use_sparse = len(first_kv_cache_tuple) == 3
|
self.use_sparse = len(first_kv_cache_tuple) == 3
|
||||||
if self.use_mla:
|
|
||||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
logger.info("num_blocks: %s", self.num_blocks)
|
||||||
|
self.block_len = []
|
||||||
|
if self.use_mla or self.use_sparse:
|
||||||
block_rank = 3 # [block_size, latent_dim]
|
block_rank = 3 # [block_size, latent_dim]
|
||||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
for i in range(len(first_kv_cache_tuple)):
|
||||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
|
||||||
self.block_len = [
|
logger.info("block_shape: %s", block_shape)
|
||||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
|
||||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
|
||||||
]
|
|
||||||
logger.info(
|
|
||||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
|
||||||
self.num_blocks,
|
|
||||||
block_shape_norm,
|
|
||||||
block_shape_pe,
|
|
||||||
)
|
|
||||||
elif self.use_sparse:
|
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
|
||||||
block_rank = 3 # [block_size, latent_dim]
|
|
||||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
|
||||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
|
||||||
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
|
|
||||||
self.block_len = [
|
|
||||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
|
||||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
|
||||||
first_kv_cache[2].element_size() * math.prod(block_shape_k),
|
|
||||||
]
|
|
||||||
logger.info(
|
|
||||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
|
|
||||||
self.num_blocks,
|
|
||||||
block_shape_norm,
|
|
||||||
block_shape_pe,
|
|
||||||
block_shape_k,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# eager:[num_block, block_size, num_head, hidden_dim]
|
# eager:[num_block, block_size, num_head, hidden_dim]
|
||||||
# torchair:[num_block, block_size, num_head*hidden_dim]
|
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
|
||||||
kv_elem_size = first_kv_cache.element_size()
|
|
||||||
block_rank = (
|
block_rank = (
|
||||||
len(first_kv_cache.shape) - 1
|
len(first_kv_cache.shape) - 1
|
||||||
) # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
|
) # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
|
||||||
block_shape = first_kv_cache.shape[-block_rank:]
|
block_shape = first_kv_cache.shape[-block_rank:]
|
||||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
logger.info("block_shape: %s", block_shape)
|
||||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
|
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
|
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
|
||||||
self.use_mla,
|
self.use_mla,
|
||||||
@@ -1252,30 +1221,15 @@ class MooncakeConnectorWorker:
|
|||||||
kv_caches_base_addr = []
|
kv_caches_base_addr = []
|
||||||
ptrs = []
|
ptrs = []
|
||||||
lengths = []
|
lengths = []
|
||||||
|
length = len(self.block_len)
|
||||||
for cache_or_caches in kv_caches.values():
|
for cache_or_caches in kv_caches.values():
|
||||||
# Normalize to always be a list of caches
|
# Normalize to always be a list of caches
|
||||||
if self.use_mla:
|
for i, cache in enumerate(cache_or_caches, 0):
|
||||||
for i, cache in enumerate(cache_or_caches, 0):
|
base_addr = cache.data_ptr()
|
||||||
base_addr = cache.data_ptr()
|
region_len = self.num_blocks * self.block_len[i % length]
|
||||||
region_len = self.num_blocks * self.block_len[i % 2]
|
kv_caches_base_addr.append(base_addr)
|
||||||
kv_caches_base_addr.append(base_addr)
|
ptrs.append(base_addr)
|
||||||
ptrs.append(base_addr)
|
lengths.append(region_len)
|
||||||
lengths.append(region_len)
|
|
||||||
elif self.use_sparse:
|
|
||||||
for i, cache in enumerate(cache_or_caches, 0):
|
|
||||||
base_addr = cache.data_ptr()
|
|
||||||
region_len = self.num_blocks * self.block_len[i % 3]
|
|
||||||
kv_caches_base_addr.append(base_addr)
|
|
||||||
ptrs.append(base_addr)
|
|
||||||
lengths.append(region_len)
|
|
||||||
else:
|
|
||||||
cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
|
|
||||||
for cache in cache_list:
|
|
||||||
base_addr = cache.data_ptr()
|
|
||||||
region_len = self.num_blocks * self.block_len[0]
|
|
||||||
kv_caches_base_addr.append(base_addr)
|
|
||||||
ptrs.append(base_addr)
|
|
||||||
lengths.append(region_len)
|
|
||||||
global_te.register_buffer(ptrs, lengths)
|
global_te.register_buffer(ptrs, lengths)
|
||||||
# After KV Caches registered, start the sending or receiving thread.
|
# After KV Caches registered, start the sending or receiving thread.
|
||||||
metadata = MooncakeAgentMetadata(
|
metadata = MooncakeAgentMetadata(
|
||||||
|
|||||||
@@ -233,15 +233,11 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
remote_block_ids, local_block_ids
|
remote_block_ids, local_block_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
block_length = len(self.block_len)
|
||||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||||
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
|
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
|
||||||
):
|
):
|
||||||
if self.use_mla:
|
block_len = self.block_len[k % block_length]
|
||||||
block_len = self.block_len[k % 2]
|
|
||||||
elif self.use_sparse:
|
|
||||||
block_len = self.block_len[k % 3]
|
|
||||||
else:
|
|
||||||
block_len = self.block_len[0]
|
|
||||||
for group_remote_block_id, group_local_block_id in zip(
|
for group_remote_block_id, group_local_block_id in zip(
|
||||||
grouped_remote_block_ids, grouped_local_block_ids
|
grouped_remote_block_ids, grouped_local_block_ids
|
||||||
):
|
):
|
||||||
@@ -925,48 +921,21 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
|
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
|
||||||
)
|
)
|
||||||
self.use_sparse = len(first_kv_cache_tuple) == 3
|
self.use_sparse = len(first_kv_cache_tuple) == 3
|
||||||
if self.use_mla:
|
|
||||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
logger.info("num_blocks: %s", self.num_blocks)
|
||||||
block_rank = 3 # [block_size, latent_dim]
|
block_rank = 3
|
||||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
self.block_len = []
|
||||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
if self.use_mla or self.use_sparse:
|
||||||
self.block_len = [
|
for i in range(len(first_kv_cache_tuple)):
|
||||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
|
||||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
logger.info("block_shape: %s", block_shape)
|
||||||
]
|
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
|
||||||
logger.info(
|
|
||||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
|
||||||
self.num_blocks,
|
|
||||||
block_shape_norm,
|
|
||||||
block_shape_pe,
|
|
||||||
)
|
|
||||||
elif self.use_sparse:
|
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
|
||||||
block_rank = 3 # [block_size, latent_dim]
|
|
||||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
|
||||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
|
||||||
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
|
|
||||||
self.block_len = [
|
|
||||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
|
||||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
|
||||||
first_kv_cache[2].element_size() * math.prod(block_shape_k),
|
|
||||||
]
|
|
||||||
logger.info(
|
|
||||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
|
|
||||||
self.num_blocks,
|
|
||||||
block_shape_norm,
|
|
||||||
block_shape_pe,
|
|
||||||
block_shape_k,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# [num_block, block_size, num_head, hidden_dim]
|
# [num_block, block_size, num_head, hidden_dim]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
|
||||||
kv_elem_size = first_kv_cache.element_size()
|
|
||||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
|
||||||
block_shape = first_kv_cache.shape[-block_rank:]
|
block_shape = first_kv_cache.shape[-block_rank:]
|
||||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
logger.info("block_shape: %s", block_shape)
|
||||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
|
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
|
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
|
||||||
@@ -979,30 +948,15 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
kv_caches_base_addr = []
|
kv_caches_base_addr = []
|
||||||
ptrs = []
|
ptrs = []
|
||||||
lengths = []
|
lengths = []
|
||||||
|
length = len(self.block_len)
|
||||||
for cache_or_caches in kv_caches.values():
|
for cache_or_caches in kv_caches.values():
|
||||||
# Normalize to always be a list of caches
|
# Normalize to always be a list of caches
|
||||||
if self.use_mla:
|
for i, cache in enumerate(cache_or_caches, 0):
|
||||||
for i, cache in enumerate(cache_or_caches, 0):
|
base_addr = cache.data_ptr()
|
||||||
base_addr = cache.data_ptr()
|
region_len = self.num_blocks * self.block_len[i % length]
|
||||||
region_len = self.num_blocks * self.block_len[i % 2]
|
kv_caches_base_addr.append(base_addr)
|
||||||
kv_caches_base_addr.append(base_addr)
|
ptrs.append(base_addr)
|
||||||
ptrs.append(base_addr)
|
lengths.append(region_len)
|
||||||
lengths.append(region_len)
|
|
||||||
elif self.use_sparse:
|
|
||||||
for i, cache in enumerate(cache_or_caches, 0):
|
|
||||||
base_addr = cache.data_ptr()
|
|
||||||
region_len = self.num_blocks * self.block_len[i % 3]
|
|
||||||
kv_caches_base_addr.append(base_addr)
|
|
||||||
ptrs.append(base_addr)
|
|
||||||
lengths.append(region_len)
|
|
||||||
else:
|
|
||||||
cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
|
|
||||||
for cache in cache_list:
|
|
||||||
base_addr = cache.data_ptr()
|
|
||||||
region_len = self.num_blocks * self.block_len[0]
|
|
||||||
kv_caches_base_addr.append(base_addr)
|
|
||||||
ptrs.append(base_addr)
|
|
||||||
lengths.append(region_len)
|
|
||||||
global_te.register_buffer(ptrs, lengths)
|
global_te.register_buffer(ptrs, lengths)
|
||||||
self.kv_caches_base_addr = kv_caches_base_addr
|
self.kv_caches_base_addr = kv_caches_base_addr
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user