From c3db1aca2ffbbc1194ff89443b51f7f85e961de5 Mon Sep 17 00:00:00 2001 From: lty Date: Sat, 7 Feb 2026 09:27:15 +0800 Subject: [PATCH] [Refactor]refactor p2p connector (#6551) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? Redundant code is removed, and repeated logic is combined through the p2p connector refactor, making the code easy to extend. ### Does this PR introduce _any_ user-facing change? NA ### How was this patch tested? P节点: ``` vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \ --host 0.0.0.0 \ --port 8002 \ --data-parallel-size 2 \ --tensor-parallel-size 8 \ --enable-expert-parallel \ --seed 1024 \ --served-model-name model \ --max-model-len 8192 \ --max-num-batched-tokens 8192 \ --max-num-seqs 16 \ --enforce-eager \ --trust-remote-code \ --gpu-memory-utilization 0.92 \ --quantization ascend \ --async-scheduling \ --additional-config '{"ascend_scheduler_config":{"enabled":true}}' \ --kv-transfer-config \ '{ "kv_connector": "MultiConnector", "kv_role": "kv_producer", "kv_connector_extra_config": { "use_layerwise": false, "connectors": [ { "kv_connector": "MooncakeConnectorV1", "kv_role": "kv_producer", "kv_port": "30000", "kv_connector_extra_config": { "use_ascend_direct": true, "prefill": { "dp_size": 2, "tp_size": 8 }, "decode": { "dp_size": 4, "tp_size": 4 } } }, { "kv_connector": "AscendStoreConnector", "kv_role": "kv_producer", "kv_connector_extra_config": { "backend": "mooncake", "mooncake_rpc_port":"0" } } ] } }' ``` D节点: ``` vllm serve /mnt/share/DeepSeek-V3.2-Exp-W8A8 \ --host 0.0.0.0 \ --port 8003 \ --data-parallel-size 4 \ --tensor-parallel-size 4 \ --enable-expert-parallel \ --seed 1024 \ --served-model-name model \ --max-model-len 8192 \ --max-num-batched-tokens 8192 \ --max-num-seqs 16 \ --enforce-eager \ --trust-remote-code \ --gpu-memory-utilization 0.92 \ --quantization ascend \ --async-scheduling \ --additional-config '{"ascend_scheduler_config":{"enabled":true}}' \ --kv-transfer-config \ '{ "kv_connector": "MultiConnector", "kv_role": "kv_consumer", "kv_connector_extra_config": { "use_layerwise": false, "connectors": [ { "kv_connector": "MooncakeConnectorV1", "kv_role": "kv_consumer", "kv_port": "30100", "kv_connector_extra_config": { "use_ascend_direct": true, "prefill": { "dp_size": 2, "tp_size": 8 }, "decode": { "dp_size": 4, "tp_size": 4 } } },{ "kv_connector": "AscendStoreConnector", "kv_role": "kv_consumer", "kv_connector_extra_config": { "backend": "mooncake", "mooncake_rpc_port":"1" } } ] } }' ``` - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 --------- Signed-off-by: lty --- .../kv_transfer/kv_p2p/mooncake_connector.py | 88 +++++-------------- .../kv_p2p/mooncake_layerwise_connector.py | 88 +++++-------------- 2 files changed, 42 insertions(+), 134 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index ca164ad8..5fc64590 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -333,7 +333,6 @@ class KVCacheRecvingThread(threading.Thread): self.block_len = block_len # TODO(jianzs): find a better way to detect MLA. self.use_mla = len(block_len) == 2 - self.use_sparse = len(block_len) == 3 self.request_queue: queue.Queue[Any] = queue.Queue() self.executor = ThreadPoolExecutor(max_workers=32) @@ -529,15 +528,11 @@ class KVCacheRecvingThread(threading.Thread): req_start_time = time.perf_counter() src_list, dst_list, length_list = [], [], [] + block_length = len(self.block_len) for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs) ): - if self.use_mla: - block_len = self.block_len[k % 2] - elif self.use_sparse: - block_len = self.block_len[k % 3] - else: - block_len = self.block_len[0] + block_len = self.block_len[k % block_length] inner_block_len = block_len // tp_num_need_pulls for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids): src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len @@ -1196,51 +1191,25 @@ class MooncakeConnectorWorker: first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2 ) self.use_sparse = len(first_kv_cache_tuple) == 3 - if self.use_mla: - # MLA case.[num_block, block_size, 1, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] + + self.num_blocks = first_kv_cache.shape[0] + logger.info("num_blocks: %s", self.num_blocks) + self.block_len = [] + if self.use_mla or self.use_sparse: block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - ) - elif self.use_sparse: - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - first_kv_cache[2].element_size() * math.prod(block_shape_k), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - block_shape_k, - ) + for i in range(len(first_kv_cache_tuple)): + block_shape = first_kv_cache_tuple[i].shape[-block_rank:] + logger.info("block_shape: %s", block_shape) + self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape)) else: # eager:[num_block, block_size, num_head, hidden_dim] - # torchair:[num_block, block_size, num_head*hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - kv_elem_size = first_kv_cache.element_size() block_rank = ( len(first_kv_cache.shape) - 1 ) # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim] block_shape = first_kv_cache.shape[-block_rank:] - self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) + logger.info("block_shape: %s", block_shape) + self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)] + logger.info( "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", self.use_mla, @@ -1252,30 +1221,15 @@ class MooncakeConnectorWorker: kv_caches_base_addr = [] ptrs = [] lengths = [] + length = len(self.block_len) for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches - if self.use_mla: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 2] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - elif self.use_sparse: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 3] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - else: - cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches - for cache in cache_list: - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[0] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % length] + kv_caches_base_addr.append(base_addr) + ptrs.append(base_addr) + lengths.append(region_len) global_te.register_buffer(ptrs, lengths) # After KV Caches registered, start the sending or receiving thread. metadata = MooncakeAgentMetadata( diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 84259848..03584cd3 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -233,15 +233,11 @@ class KVCacheSendingLayerThread(threading.Thread): remote_block_ids, local_block_ids ) + block_length = len(self.block_len) for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) ): - if self.use_mla: - block_len = self.block_len[k % 2] - elif self.use_sparse: - block_len = self.block_len[k % 3] - else: - block_len = self.block_len[0] + block_len = self.block_len[k % block_length] for group_remote_block_id, group_local_block_id in zip( grouped_remote_block_ids, grouped_local_block_ids ): @@ -925,48 +921,21 @@ class MooncakeLayerwiseConnectorWorker: first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2 ) self.use_sparse = len(first_kv_cache_tuple) == 3 - if self.use_mla: - # MLA case.[num_block, block_size, 1, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - ) - elif self.use_sparse: - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - first_kv_cache[2].element_size() * math.prod(block_shape_k), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - block_shape_k, - ) + + self.num_blocks = first_kv_cache.shape[0] + logger.info("num_blocks: %s", self.num_blocks) + block_rank = 3 + self.block_len = [] + if self.use_mla or self.use_sparse: + for i in range(len(first_kv_cache_tuple)): + block_shape = first_kv_cache_tuple[i].shape[-block_rank:] + logger.info("block_shape: %s", block_shape) + self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape)) else: # [num_block, block_size, num_head, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - kv_elem_size = first_kv_cache.element_size() - block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) + logger.info("block_shape: %s", block_shape) + self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)] logger.info( "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", @@ -979,30 +948,15 @@ class MooncakeLayerwiseConnectorWorker: kv_caches_base_addr = [] ptrs = [] lengths = [] + length = len(self.block_len) for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches - if self.use_mla: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 2] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - elif self.use_sparse: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 3] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - else: - cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches - for cache in cache_list: - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[0] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % length] + kv_caches_base_addr.append(base_addr) + ptrs.append(base_addr) + lengths.append(region_len) global_te.register_buffer(ptrs, lengths) self.kv_caches_base_addr = kv_caches_base_addr