From c2d58c06551d7d82373ca6e4a46217f218906629 Mon Sep 17 00:00:00 2001 From: wangxiaoteng888 <56506195+wangxiaoteng888@users.noreply.github.com> Date: Sun, 9 Nov 2025 09:55:10 +0800 Subject: [PATCH] [P/D][BugFix][v0.11.0-dev]Fix proxy format processing errors & Layerwise connector performance optimization (#4069) ### What this PR does / why we need it? 1.Fix proxy format processing errors. 2.Layer-wise connector performance optimization Signed-off-by: wangxiaoteng --- .../load_balance_proxy_layerwise_server_example.py | 2 +- .../ut/kv_connector/test_mooncake_layerwise_connector.py | 8 ++++++++ vllm_ascend/distributed/mooncake_layerwise_connector.py | 9 +++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py index 82e06e0..8bbc359 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py @@ -447,7 +447,7 @@ def get_api_request_id(api, req_id): def get_origin_request_id(api, req_id): if api == "/completions": - return req_id.replace("cmpl-", "").replace("-0", "") + return req_id.replace("cmpl-", "")[:-2] elif api == "/chat/completions": return req_id.replace("chatcmpl-", "") diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index 31611ac..5de4ca8 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -32,6 +32,14 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): self.engine = MagicMock() self.engine.register_memory.return_value = 0 self.engine.batch_transfer_sync_write.return_value = 1 + self._patcher_cs = patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.torch_npu.npu.current_stream' + ) + self.mock_current_stream = self._patcher_cs.start() + self.addCleanup(self._patcher_cs.stop) + fake_stream = MagicMock(name="FakeStream") + fake_stream.synchronize = MagicMock() + self.mock_current_stream.return_value = fake_stream self.first_kv_cache = torch.zeros((2, 2, 2, 8), dtype=torch.float32, diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index 79d9bdb..ff3e85f 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -19,6 +19,7 @@ import msgspec import numpy as np import numpy.typing as npt import torch +import torch_npu import zmq from mooncake.engine import TransferEngine # type: ignore from vllm.config import VllmConfig @@ -87,6 +88,8 @@ class KVCacheSendingLayerThread(threading.Thread): self.total_layers = total_layers self.use_mla = use_mla self.block_len = block_len + self.model_stream = torch_npu.npu.current_stream() + self.current_layer = -1 if self.pd_head_ratio > 1: # regesit kv buffer for tp inequal @@ -186,7 +189,9 @@ class KVCacheSendingLayerThread(threading.Thread): src_list.append(src) dst_list.append(dst) length_list.append(length) - torch.npu.synchronize() + if self.current_layer != layer_index: + self.current_layer = layer_index + self.model_stream.synchronize() ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) if ret < 0: @@ -237,7 +242,7 @@ class KVCacheSendingLayerThread(threading.Thread): ((self.tp_rank // self.num_head_replica) % self.pd_head_ratio)) src_layer_addr += length - torch.npu.synchronize() + self.model_stream.synchronize() ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) if ret < 0: