[P/D][BugFix]Fix proxy format processing errors & Layerwise connector performance optimization (#4043)
### What this PR does / why we need it?
1. Fix proxy format processing errors.
2. Layer-wise connector performance optimization.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -447,7 +447,7 @@ def get_api_request_id(api, req_id):
|
|||||||
|
|
||||||
def get_origin_request_id(api, req_id):
|
def get_origin_request_id(api, req_id):
|
||||||
if api == "/completions":
|
if api == "/completions":
|
||||||
return req_id.replace("cmpl-", "").replace("-0", "")
|
return req_id.replace("cmpl-", "")[:-2]
|
||||||
elif api == "/chat/completions":
|
elif api == "/chat/completions":
|
||||||
return req_id.replace("chatcmpl-", "")
|
return req_id.replace("chatcmpl-", "")
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
self.engine = MagicMock()
|
self.engine = MagicMock()
|
||||||
self.engine.register_memory.return_value = 0
|
self.engine.register_memory.return_value = 0
|
||||||
self.engine.batch_transfer_sync_write.return_value = 1
|
self.engine.batch_transfer_sync_write.return_value = 1
|
||||||
|
self._patcher_cs = patch(
|
||||||
|
'vllm_ascend.distributed.mooncake_layerwise_connector.torch_npu.npu.current_stream'
|
||||||
|
)
|
||||||
|
self.mock_current_stream = self._patcher_cs.start()
|
||||||
|
self.addCleanup(self._patcher_cs.stop)
|
||||||
|
fake_stream = MagicMock(name="FakeStream")
|
||||||
|
fake_stream.synchronize = MagicMock()
|
||||||
|
self.mock_current_stream.return_value = fake_stream
|
||||||
|
|
||||||
self.first_kv_cache = torch.zeros((2, 2, 2, 8),
|
self.first_kv_cache = torch.zeros((2, 2, 2, 8),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import msgspec
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
|
import torch_npu
|
||||||
import zmq
|
import zmq
|
||||||
from mooncake.engine import TransferEngine # type: ignore
|
from mooncake.engine import TransferEngine # type: ignore
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@@ -93,6 +94,8 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
self.total_layers = total_layers
|
self.total_layers = total_layers
|
||||||
self.use_mla = use_mla
|
self.use_mla = use_mla
|
||||||
self.block_len = block_len
|
self.block_len = block_len
|
||||||
|
self.model_stream = torch_npu.npu.current_stream()
|
||||||
|
self.current_layer = -1
|
||||||
|
|
||||||
if self.pd_head_ratio > 1:
|
if self.pd_head_ratio > 1:
|
||||||
# regesit kv buffer for tp inequal
|
# regesit kv buffer for tp inequal
|
||||||
@@ -192,7 +195,9 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
src_list.append(src)
|
src_list.append(src)
|
||||||
dst_list.append(dst)
|
dst_list.append(dst)
|
||||||
length_list.append(length)
|
length_list.append(length)
|
||||||
torch.npu.synchronize()
|
if self.current_layer != layer_index:
|
||||||
|
self.current_layer = layer_index
|
||||||
|
self.model_stream.synchronize()
|
||||||
ret = self.engine.batch_transfer_sync_write(
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
session_id, src_list, dst_list, length_list)
|
session_id, src_list, dst_list, length_list)
|
||||||
if ret < 0:
|
if ret < 0:
|
||||||
@@ -243,7 +248,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
((self.tp_rank // self.num_head_replica) %
|
((self.tp_rank // self.num_head_replica) %
|
||||||
self.pd_head_ratio))
|
self.pd_head_ratio))
|
||||||
src_layer_addr += length
|
src_layer_addr += length
|
||||||
torch.npu.synchronize()
|
self.model_stream.synchronize()
|
||||||
ret = self.engine.batch_transfer_sync_write(
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
session_id, src_list, dst_list, length_list)
|
session_id, src_list, dst_list, length_list)
|
||||||
if ret < 0:
|
if ret < 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user