[P/D] Performance enhancement of Layerwise connector in TP asymmetric scenarios (#5540)

### What this PR does / why we need it?
[P/D] Performance enhancement of Layerwise connector in TP asymmetric
scenarios
1. Session fusion: For transmission tasks at each layer, aggregate
transmission tasks with the same destination and merge them into a
single task for assignment.
2. Alltoall aggregation: For TP asymmetric scenarios, perform all
alltoall operations at once according to the block granularity for all
requests.

[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache
Layerwise Push Support
https://github.com/vllm-project/vllm-ascend/issues/4842
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
liziyu
2026-01-06 20:25:36 +08:00
committed by GitHub
parent cd1162e25a
commit 330e25ab1d
2 changed files with 341 additions and 275 deletions

View File

@@ -1,3 +1,4 @@
import contextlib
import os import os
import sys import sys
import threading import threading
@@ -18,9 +19,9 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole, KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
MooncakeAgentMetadata, MooncakeLayerwiseConnector, MooncakeAgentMetadata, MooncakeLayerwiseConnector,
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler, MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv, MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, SendTask,
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash, ensure_zmq_recv, ensure_zmq_send, group_concurrent_contiguous,
zmq_ctx) string_to_int64_hash, zmq_ctx)
GET_META_MSG = b"get_meta_msg" GET_META_MSG = b"get_meta_msg"
DONE_SENDING_MSG = b"done_sending_msg" DONE_SENDING_MSG = b"done_sending_msg"
@@ -32,14 +33,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.engine = MagicMock() self.engine = MagicMock()
self.engine.register_memory.return_value = 0 self.engine.register_memory.return_value = 0
self.engine.batch_transfer_sync_write.return_value = 1 self.engine.batch_transfer_sync_write.return_value = 1
self._patcher_cs = patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.torch_npu.npu.current_stream'
)
self.mock_current_stream = self._patcher_cs.start()
self.addCleanup(self._patcher_cs.stop)
fake_stream = MagicMock(name="FakeStream") fake_stream = MagicMock(name="FakeStream")
fake_stream.synchronize = MagicMock() fake_stream.synchronize = MagicMock()
self.mock_current_stream.return_value = fake_stream
self.first_kv_cache = torch.zeros((2, 2, 2, 8), self.first_kv_cache = torch.zeros((2, 2, 2, 8),
dtype=torch.float32, dtype=torch.float32,
@@ -47,6 +42,14 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.ready_event = threading.Event() self.ready_event = threading.Event()
self.fake_k_buffer = MagicMock()
self.fake_v_buffer = MagicMock()
fake_resharding_stream = MagicMock()
cap = self.first_kv_cache.numel() // self.first_kv_cache.shape[-1]
dim = self.first_kv_cache.shape[-1]
self.key = torch.zeros((cap, dim), dtype=torch.float32)
self.value = torch.zeros((cap, dim), dtype=torch.float32)
self.thread = KVCacheSendingLayerThread( self.thread = KVCacheSendingLayerThread(
engine=self.engine, engine=self.engine,
total_layers=3, total_layers=3,
@@ -60,6 +63,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
block_len=[1024, 2048], block_len=[1024, 2048],
decode_tp_size=1, decode_tp_size=1,
first_kv_cache=self.first_kv_cache, first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
callback_func=MagicMock()) callback_func=MagicMock())
self.req_meta_base = ReqMeta( self.req_meta_base = ReqMeta(
@@ -74,6 +80,9 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
metaserver="http://dummy", metaserver="http://dummy",
chunk_finish=False) chunk_finish=False)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.npu_stream_switch",
side_effect=lambda *_args, **_kwargs: contextlib.nullcontext())
@patch( @patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr", "vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
autospec=True, autospec=True,
@@ -87,7 +96,10 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous" "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
) )
def test_transfer_pd_gt1_uses_buffers_and_calls_engine( def test_transfer_pd_gt1_uses_buffers_and_calls_engine(
self, mock_group, _mock_sync, _mock_align, _mock_dataptr): self, mock_group, _mock_sync, _mock_align, _mock_dataptr,
mock_stream_switch):
fake_resharding_stream = MagicMock()
thread = KVCacheSendingLayerThread( thread = KVCacheSendingLayerThread(
engine=self.engine, engine=self.engine,
@@ -101,26 +113,28 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
block_len=[64], block_len=[64],
decode_tp_size=1, decode_tp_size=1,
first_kv_cache=self.first_kv_cache, first_kv_cache=self.first_kv_cache,
k_buffer=self.fake_k_buffer,
v_buffer=self.fake_v_buffer,
resharding_stream=fake_resharding_stream,
callback_func=MagicMock()) callback_func=MagicMock())
req_meta = self.req_meta_base req_meta = self.req_meta_base
req_meta.remote_kv_caches_base_addr = [4000, 8000] req_meta.remote_kv_caches_base_addr = [4000, 8000]
mock_group.return_value = ([[10, 11], [20, 21]], []) mock_group.return_value = ([[10, 11], [20, 21]], [])
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)
cap = self.first_kv_cache.numel() // self.first_kv_cache.shape[-1] send_task = SendTask(
dim = self.first_kv_cache.shape[-1] send_request={"req1": req_meta},
wait_event=MagicMock(),
k_cache=key,
v_cache=value,
layer_idx=0,
rearrange_block_ids=[5, 8],
)
key = torch.zeros((cap, dim), dtype=torch.float32) thread._transfer_kv_cache(send_task)
value = torch.zeros((cap, dim), dtype=torch.float32)
thread._transfer_kv_cache( # type: ignore
req_id="req1",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=MagicMock())
self.engine.batch_transfer_sync_write.assert_called_once() self.engine.batch_transfer_sync_write.assert_called_once()
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
@@ -145,60 +159,45 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
def test_transfer_skips_when_no_local_blocks(self): def test_transfer_skips_when_no_local_blocks(self):
req_meta = self.req_meta_base req_meta = self.req_meta_base
req_meta.local_block_ids = [] req_meta.local_block_ids = []
self.thread.pd_head_ratio = 1 send_task = SendTask(
self.thread.block_len = [64, 128] send_request={"req2": req_meta},
wait_event=MagicMock(),
key = torch.zeros((1, 8), dtype=torch.float32) k_cache=torch.zeros((1, 8)),
value = torch.zeros((1, 8), dtype=torch.float32) v_cache=torch.zeros((1, 8)),
layer_idx=0,
reshape_cache_event = MagicMock() rearrange_block_ids=[],
with patch.object(self.engine, )
'batch_transfer_sync_write') as mock_batch_transfer: self.thread._transfer_kv_cache(send_task)
mock_batch_transfer.return_value = 1 self.engine.batch_transfer_sync_write.assert_not_called()
def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key,
value,
reshape_cache_event): # type: ignore
if not req_meta.local_block_ids:
return
self._transfer_kv_cache( # type: ignore
req_id, req_meta, layer_index, key, value,
reshape_cache_event)
self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore
self.thread._transfer_kv_cache( # type: ignore
req_id="req2",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=reshape_cache_event)
mock_batch_transfer.assert_not_called()
self.assertEqual(mock_batch_transfer.call_count, 0)
def test_transfer_skips_when_tp_not_sender(self): def test_transfer_skips_when_tp_not_sender(self):
thread = KVCacheSendingLayerThread(engine=self.engine, thread = KVCacheSendingLayerThread(
total_layers=2, engine=self.engine,
ready_event=self.ready_event, total_layers=2,
tp_rank=1, ready_event=self.ready_event,
pd_head_ratio=1, tp_rank=1,
num_head_replica=2, pd_head_ratio=1,
kv_cache_base_addr=[1000, 2000], num_head_replica=2,
use_mla=False, kv_cache_base_addr=[1000, 2000, 3000, 4000],
block_len=[1024], use_mla=False,
decode_tp_size=1, block_len=[1024],
first_kv_cache=self.first_kv_cache, decode_tp_size=1,
callback_func=MagicMock()) first_kv_cache=self.first_kv_cache,
k_buffer=MagicMock(),
v_buffer=MagicMock(),
resharding_stream=MagicMock(),
callback_func=MagicMock())
req_meta = self.req_meta_base req_meta = self.req_meta_base
thread._transfer_kv_cache( # type: ignore send_task = SendTask(
"req3", send_request={"req3": req_meta},
req_meta, wait_event=MagicMock(),
0, k_cache=self.key,
torch.zeros((1, 8)), v_cache=self.value,
torch.zeros((1, 8)), layer_idx=1,
reshape_cache_event=MagicMock()) rearrange_block_ids=[],
)
thread._transfer_kv_cache(send_task)
self.engine.batch_transfer_sync_write.assert_not_called() self.engine.batch_transfer_sync_write.assert_not_called()
@patch( @patch(
@@ -208,30 +207,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize" "vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
) )
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
req_meta = self.req_meta_base req_meta = self.req_meta_base
req_meta.chunk_finish = True
req_meta.local_block_ids = [5, 6] req_meta.local_block_ids = [5, 6]
req_meta.remote_block_ids = [10, 11] req_meta.remote_block_ids = [10, 11]
req_meta.remote_kv_caches_base_addr = [ req_meta.remote_kv_caches_base_addr = [
7000, 8000, 9000, 10000, 11000, 12000 7000, 8000, 9000, 10000, 11000, 12000
] ]
req_meta.chunk_finish = True
key = torch.zeros((1, 8), dtype=torch.float32) key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32) value = torch.zeros((1, 8), dtype=torch.float32)
send_task = MagicMock() send_task = SendTask(
send_task.layer_index = self.thread.total_layers - 1 send_request={"req5": req_meta},
send_task.send_request = {"req5": req_meta} wait_event=MagicMock(),
k_cache=key,
v_cache=value,
layer_idx=2,
rearrange_block_ids=[],
)
self.thread._transfer_kv_cache(send_task)
with patch.object(self.thread, 'callback_func') as mock_callback_func: self.thread.callback_func.assert_called_once()
self.thread._transfer_kv_cache( # type: ignore
req_id="req5",
req_meta=req_meta,
layer_index=send_task.layer_index,
key=key,
value=value,
reshape_cache_event=MagicMock())
print(f"Callback called: {mock_callback_func.call_count} times")
mock_callback_func.assert_called_once()
class TestKVCacheRecvingLayerThread(unittest.TestCase): class TestKVCacheRecvingLayerThread(unittest.TestCase):
@@ -506,10 +505,10 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase):
self.assertTrue(async_flag) self.assertTrue(async_flag)
def test_build_connector_meta(self): def test_build_connector_meta(self):
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
request = MockRequest("req1") request = MockRequest("req1")
self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6]) self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
request.kv_transfer_params = { request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3], "remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote", "remote_engine_id": "remote",
@@ -554,9 +553,9 @@ class _MockSchedulerOutput:
new_block_ids=cached_new_block_ids or [], new_block_ids=cached_new_block_ids or [],
num_computed_tokens=cached_num_computed or [], num_computed_tokens=cached_num_computed or [],
) )
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
self.scheduled_new_reqs = new_reqs or [] self.scheduled_new_reqs = new_reqs or []
self.num_scheduled_tokens = num_sched or {} self.num_scheduled_tokens = num_sched or {}
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
@@ -593,14 +592,12 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True)) self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))
def test_update_state_after_alloc_decode_records_send_layerwise(self): def test_update_state_after_alloc_decode_records_send_layerwise(self):
req = MockRequest( req = MockRequest("req_u2",
"req_u2", prompt_token_ids=list(range(10)),
prompt_token_ids=list(range(10)), kv_transfer_params={
kv_transfer_params={ "do_remote_decode": True,
"do_remote_decode": True, "remote_block_ids": []
"remote_block_ids": [] # 修改为空列表 [] })
})
blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
self.scheduler.update_state_after_alloc(req, self.scheduler.update_state_after_alloc(req,
blocks, blocks,
@@ -610,7 +607,24 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
self.assertEqual(info.local_block_ids, [7, 8, 9]) self.assertEqual(info.local_block_ids, [7, 8, 9])
self.assertIs(info.request, req) self.assertIs(info.request, req)
self.assertEqual(info.remote_block_ids, []) self.assertEqual(info.remote_block_ids, [])
self.assertIsInstance(info.remote_block_ids, list)
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
req = MockRequest("req_b1",
kv_transfer_params={
"remote_block_ids": [1, 2],
"remote_engine_id": "E",
"remote_host": "H",
"remote_port": 5555,
"remote_te_rpc_port": 6000,
"remote_kv_caches_base_addr": [10, 11],
})
self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101])
meta = self.scheduler.build_connector_meta(_MockSchedulerOutput())
self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata)
self.assertIn("req_b1", meta.requests)
self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_build_connector_meta_accumulates_cached_blocks(self): def test_build_connector_meta_accumulates_cached_blocks(self):
req_meta = MagicMock(spec=ReqMeta) req_meta = MagicMock(spec=ReqMeta)
@@ -637,8 +651,6 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
meta = self.scheduler.build_connector_meta(out) meta = self.scheduler.build_connector_meta(out)
self.assertEqual(len(meta.requests), 0) self.assertEqual(len(meta.requests), 0)
req_meta.extend_local_block_ids.assert_called_once_with([3, 4])
@patch( @patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous" "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
) )
@@ -684,6 +696,12 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
send_req_info.extend_local_block_ids.assert_called_once_with([50]) send_req_info.extend_local_block_ids.assert_called_once_with([50])
self.assertIn("req_b3", meta.requests) self.assertIn("req_b3", meta.requests)
def test_request_finished_returns_false_none(self):
ok, params = self.scheduler.request_finished(MockRequest("req_fin"),
[1, 2])
self.assertFalse(ok)
self.assertIsNone(params)
class TestHelperFunctions(unittest.TestCase): class TestHelperFunctions(unittest.TestCase):

View File

@@ -11,7 +11,7 @@ import time
from collections import OrderedDict, defaultdict, deque from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import httpx import httpx
@@ -19,7 +19,6 @@ import msgspec
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import torch import torch
import torch_npu
import zmq import zmq
from mooncake.engine import TransferEngine # type: ignore from mooncake.engine import TransferEngine # type: ignore
from vllm.config import VllmConfig from vllm.config import VllmConfig
@@ -37,6 +36,7 @@ from vllm_ascend.distributed.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.utils import (align_memory, from vllm_ascend.distributed.utils import (align_memory,
get_transfer_timeout_value, get_transfer_timeout_value,
kv_alltoall_and_rearrange) kv_alltoall_and_rearrange)
from vllm_ascend.utils import npu_stream_switch
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
@@ -56,7 +56,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
@dataclass @dataclass
class ReqMeta: class ReqMeta:
local_block_ids: list[int] local_block_ids: list[int]
token_ids: list[int] token_ids: Optional[list[int]]
# Not None if layer-wise is disabled # Not None if layer-wise is disabled
remote_block_ids: list[int] remote_block_ids: list[int]
remote_engine_id: Optional[str] remote_engine_id: Optional[str]
@@ -68,6 +68,26 @@ class ReqMeta:
chunk_finish: Optional[bool] chunk_finish: Optional[bool]
@dataclass
class SendTask:
send_request: dict[str, ReqMeta] = field(default_factory=dict)
# pd_head_ratio == 1 use
wait_event: Optional[torch.npu.Event] = None
# pd_head_ratio > 1 use
k_cache: Optional[torch.Tensor] = None
v_cache: Optional[torch.Tensor] = None
layer_idx: int = 0
rearrange_block_ids: Optional[list[int]] = None
@dataclass
class TransferMeta:
src: list[int]
dst: list[int]
length: list[int]
req_ids: list[str]
@dataclass @dataclass
class SendReqInfo: class SendReqInfo:
local_block_ids: list[int] local_block_ids: list[int]
@@ -116,19 +136,24 @@ class SizedDict(OrderedDict):
class KVCacheSendingLayerThread(threading.Thread): class KVCacheSendingLayerThread(threading.Thread):
def __init__(self, def __init__(
engine: TransferEngine, self,
total_layers: int, engine: TransferEngine,
ready_event: threading.Event, total_layers: int,
tp_rank: int, ready_event: threading.Event,
pd_head_ratio: int, tp_rank: int,
num_head_replica: int, pd_head_ratio: int,
kv_cache_base_addr: list[int], num_head_replica: int,
use_mla: bool, kv_cache_base_addr: list[int],
block_len: list[int], use_mla: bool,
decode_tp_size: int, block_len: list[int],
first_kv_cache: torch.Tensor, decode_tp_size: int,
callback_func: Callable[..., None] = lambda x: None): first_kv_cache: torch.Tensor,
k_buffer: torch.Tensor,
v_buffer: torch.Tensor,
resharding_stream: torch.npu.Stream,
callback_func: Callable[..., None] = lambda x: None,
):
super().__init__(daemon=True, name="KVCacheSendingLayerThread") super().__init__(daemon=True, name="KVCacheSendingLayerThread")
self.engine = engine self.engine = engine
self.tp_rank = tp_rank self.tp_rank = tp_rank
@@ -139,39 +164,12 @@ class KVCacheSendingLayerThread(threading.Thread):
self.use_mla = use_mla self.use_mla = use_mla
self.block_len = block_len self.block_len = block_len
self._decode_tp_size = decode_tp_size self._decode_tp_size = decode_tp_size
self.model_stream = torch_npu.npu.current_stream() self.resharding_stream = resharding_stream
self.current_layer = -1 self.current_layer = -1
if self.pd_head_ratio > 1: self.send_queue = queue.Queue[SendTask]()
# regesit kv buffer for tp inequal self.k_buffer = k_buffer
alignment = 2 * 1024 * 1024 self.v_buffer = v_buffer
self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.k_buffer = align_memory(
self.k_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.v_buffer = align_memory(
self.v_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
for tensor in (self.k_buffer, self.v_buffer):
assert tensor.data_ptr(
) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
ret_value = self.engine.register_memory(
tensor.data_ptr(), tensor.numel())
logger.info(
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
torch.Tensor, torch.npu.Event]]()
self.ready_event = ready_event self.ready_event = ready_event
self.callback_func = callback_func self.callback_func = callback_func
@@ -181,43 +179,36 @@ class KVCacheSendingLayerThread(threading.Thread):
torch.npu.set_device(device) torch.npu.set_device(device)
self.ready_event.set() self.ready_event.set()
while True: while True:
req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get( send_task = self.send_queue.get()
) self._handle_request(send_task)
self._handle_request(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
def _handle_request(self, req_id, req_meta, layer_index, key, value, def _handle_request(self, send_task: SendTask):
reshape_cache_event):
try: try:
logger.debug( self._transfer_kv_cache(send_task)
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
logger.debug(
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
except Exception as e: except Exception as e:
logger.error("Failed to transfer KV cache for request " logger.error(
f"{req_id}: {e}") f"Failed to transfer KV cache for layer idx {send_task.layer_idx}, {e}"
)
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, def get_transfer_meta(self, send_task: SendTask, req_id: str,
reshape_cache_event): req_meta: ReqMeta):
src_list: list[str] = []
dst_list: list[str] = []
length_list: list[int] = []
# not need to send kv cache # not need to send kv cache
if self.tp_rank % self.num_head_replica != 0: if self.tp_rank % self.num_head_replica != 0:
logger.debug( logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})." f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
) )
return return (src_list, dst_list, length_list)
if self.use_mla and self.tp_rank >= self._decode_tp_size: if self.use_mla and self.tp_rank >= self._decode_tp_size:
logger.debug( logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})." f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
) )
return return (src_list, dst_list, length_list)
remote_host = req_meta.remote_host layer_idx = send_task.layer_idx
remote_block_ids = req_meta.remote_block_ids remote_block_ids = req_meta.remote_block_ids
remote_te_port = req_meta.remote_te_rpc_port
remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr
local_kv_base_addr = self.kv_caches_base_addr local_kv_base_addr = self.kv_caches_base_addr
local_block_ids = req_meta.local_block_ids local_block_ids = req_meta.local_block_ids
@@ -225,17 +216,15 @@ class KVCacheSendingLayerThread(threading.Thread):
if self.pd_head_ratio == 1: if self.pd_head_ratio == 1:
layer_local_kv_base_addr = [ layer_local_kv_base_addr = [
local_kv_base_addr[i] local_kv_base_addr[i]
for i in [2 * layer_index, 2 * layer_index + 1] for i in [2 * layer_idx, 2 * layer_idx + 1]
] ]
layer_remote_kv_base_addr = [ layer_remote_kv_base_addr = [
remote_kv_base_addrs[i] remote_kv_base_addrs[i] # type:ignore
for i in [2 * layer_index, 2 * layer_index + 1] for i in [2 * layer_idx, 2 * layer_idx + 1]
] ]
grouped_remote_block_ids, grouped_local_block_ids = \ grouped_remote_block_ids, grouped_local_block_ids = \
group_concurrent_contiguous(remote_block_ids, local_block_ids) group_concurrent_contiguous(remote_block_ids, local_block_ids)
session_id = f"{remote_host}:{remote_te_port}"
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
block_len = self.block_len[ block_len = self.block_len[
@@ -250,74 +239,101 @@ class KVCacheSendingLayerThread(threading.Thread):
src_list.append(src) src_list.append(src)
dst_list.append(dst) dst_list.append(dst)
length_list.append(length) length_list.append(length)
if self.current_layer != layer_index:
self.current_layer = layer_index
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
reshape_cache_event.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
logger.error("Mooncake transfer failed for request %s", req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
else: else:
key = key.view(-1, key.shape[-1]) rearrange_block_ids = send_task.rearrange_block_ids
value = value.view(-1, key.shape[-1]) rearrange_block_dict = {
self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] -> value: index
self.v_buffer[:value.shape[0]].copy_(value) for index, value in enumerate(
rearrange_block_ids) # type:ignore
}
layer_local_kv_base_addr = [ layer_local_kv_base_addr = [
self.k_buffer.data_ptr(), self.k_buffer.data_ptr(),
self.v_buffer.data_ptr() self.v_buffer.data_ptr()
] ]
layer_remote_kv_base_addr = [ layer_remote_kv_base_addr = [
remote_kv_base_addrs[i] remote_kv_base_addrs[i] # type:ignore
for i in [2 * layer_index, 2 * layer_index + 1] for i in [2 * layer_idx, 2 * layer_idx + 1]
] ]
grouped_remote_block_ids, _ = group_concurrent_contiguous(
remote_block_ids)
session_id = f"{remote_host}:{remote_te_port}"
src_list, dst_list, length_list = [], [], [] src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
src_layer_addr = src_layer_base_addr block_len = self.block_len[0]
for group_remote_block_id in grouped_remote_block_ids: remote_block_len = self.block_len[0] * self.pd_head_ratio
block_len = self.block_len[0] for remote_block_id, local_block_id in zip(
remote_block_len = self.block_len[0] * self.pd_head_ratio remote_block_ids, local_block_ids):
src_list.append(src_layer_addr) src = src_layer_base_addr + rearrange_block_dict[
local_block_id] * block_len
dst = dst_layer_base_addr + remote_block_id * remote_block_len + block_len * (
(self.tp_rank // self.num_head_replica) %
self.pd_head_ratio)
src_list.append(src)
dst_list.append(dst)
length_list.append(block_len)
return (src_list, dst_list, length_list)
if src_layer_addr + len( def _transfer_kv_cache(self, send_task: SendTask):
group_remote_block_id if self.pd_head_ratio > 1:
) * block_len > src_layer_base_addr + key.numel( with npu_stream_switch(self.resharding_stream):
) * key.element_size(): key = send_task.k_cache
length = src_layer_base_addr + key.numel( value = send_task.v_cache
) * key.element_size() - src_layer_addr key = key.view(-1, key.shape[-1]) # type:ignore
else: value = value.view(-1, key.shape[-1]) # type:ignore
length = len(group_remote_block_id) * block_len self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] ->
length_list.append(length) self.v_buffer[:value.shape[0]].copy_(value)
dst_list.append(dst_layer_base_addr + # Merge transmission tasks of the same session
group_remote_block_id[0] * session_meta: dict[str, TransferMeta] = {}
remote_block_len + length * for req_id, req_meta in send_task.send_request.items():
((self.tp_rank // self.num_head_replica) % session_id = f"{req_meta.remote_host}:{req_meta.remote_te_rpc_port}"
self.pd_head_ratio)) if session_id not in session_meta.keys():
src_layer_addr += length session_meta[session_id] = TransferMeta(src=[],
self.model_stream.synchronize() dst=[],
ret = self.engine.batch_transfer_sync_write( length=[],
session_id, src_list, dst_list, length_list) req_ids=[])
if ret < 0:
logger.error("Mooncake transfer failed for request %s", req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
if layer_index == (self.total_layers - 1) and req_meta.chunk_finish: (src_list, dst_list,
self.callback_func(req_id, req_meta) length_list) = self.get_transfer_meta(send_task, req_id, req_meta)
session_meta[session_id].src.extend(src_list)
session_meta[session_id].dst.extend(dst_list)
session_meta[session_id].length.extend(length_list)
session_meta[session_id].req_ids.append(req_id)
if self.pd_head_ratio == 1:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
send_task.wait_event.synchronize() # type:ignore
elif self.pd_head_ratio > 1:
self.resharding_stream.synchronize()
for session_id, transfer_meta in session_meta.items():
if len(transfer_meta.src) > 0:
ret = self.engine.batch_transfer_sync_write(
session_id, transfer_meta.src, transfer_meta.dst,
transfer_meta.length)
if ret < 0:
logger.error(
f"Mooncake transfer failed for send requests {transfer_meta.req_ids} kv cache to {session_id}"
)
if send_task.layer_idx == (self.total_layers - 1):
for req_id in transfer_meta.req_ids:
req_meta = send_task.send_request[req_id]
if req_meta.chunk_finish:
self.callback_func(
req_id, req_meta
) # TODO Send a signal indicating transmission failure
else:
if send_task.layer_idx == (self.total_layers - 1):
for req_id in transfer_meta.req_ids:
req_meta = send_task.send_request[req_id]
if req_meta.chunk_finish:
self.callback_func(req_id, req_meta)
class KVCacheRecvingLayerThread(threading.Thread): class KVCacheRecvingLayerThread(threading.Thread):
@@ -836,8 +852,10 @@ class MooncakeLayerwiseConnectorWorker:
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
self.pd_head_ratio = get_ascend_config().pd_head_ratio self.pd_head_ratio = get_ascend_config().pd_head_ratio
self.num_head_replica = get_ascend_config().num_head_replica self.num_head_replica = get_ascend_config().num_head_replica
self.resharding_stream = None
if self.pd_head_ratio > 1:
self.resharding_stream = torch.npu.Stream()
self.first_kv_cache = None
self.remote_poller = zmq.Poller() # type: ignore self.remote_poller = zmq.Poller() # type: ignore
self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
self.encoder = msgspec.msgpack.Encoder() self.encoder = msgspec.msgpack.Encoder()
@@ -852,6 +870,8 @@ class MooncakeLayerwiseConnectorWorker:
deque) deque)
self.remote_poller = zmq.Poller() # type: ignore self.remote_poller = zmq.Poller() # type: ignore
self.timeout = 1.0 # seconds self.timeout = 1.0 # seconds
self.k_buffer: Optional[torch.Tensor] = None
self.v_buffer: Optional[torch.Tensor] = None
def _get_prefill_decode_size(self, vllm_config: VllmConfig): def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config # get prefill tp and dp size from extra config
@@ -874,12 +894,40 @@ class MooncakeLayerwiseConnectorWorker:
assert "dp_size" in decode_parallel_config.keys() assert "dp_size" in decode_parallel_config.keys()
self._decode_dp_size = decode_parallel_config["dp_size"] self._decode_dp_size = decode_parallel_config["dp_size"]
def create_kv_buffer(self, first_kv_cache):
if self.pd_head_ratio > 1:
# regesit kv buffer for tp inequal
alignment = 2 * 1024 * 1024
self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.k_buffer = align_memory(
self.k_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.v_buffer = align_memory(
self.v_buffer, alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
for tensor in (self.k_buffer, self.v_buffer):
assert tensor.data_ptr(
) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
ret_value = self.engine.register_memory(
tensor.data_ptr(), tensor.numel())
logger.info(
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data.""" """Register the KV Cache data."""
_, first_kv_cache_tuple = next(iter(kv_caches.items())) _, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0] first_kv_cache = first_kv_cache_tuple[0]
self.first_kv_cache = first_kv_cache self.create_kv_buffer(first_kv_cache)
# TODO(tms): Find a more robust way to detect and handle MLA # TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size( self.use_mla = first_kv_cache_tuple[0].size(
@@ -954,6 +1002,9 @@ class MooncakeLayerwiseConnectorWorker:
block_len=self.block_len, block_len=self.block_len,
decode_tp_size=self._decode_tp_size, decode_tp_size=self._decode_tp_size,
first_kv_cache=first_kv_cache, first_kv_cache=first_kv_cache,
k_buffer=self.k_buffer,
v_buffer=self.v_buffer,
resharding_stream=self.resharding_stream,
callback_func=self.send_done_send_signal) callback_func=self.send_done_send_signal)
self.kv_send_layer_thread.start() self.kv_send_layer_thread.start()
ready_event.wait() ready_event.wait()
@@ -1002,60 +1053,49 @@ class MooncakeLayerwiseConnectorWorker:
reshape_cache_event = attn_metadata.reshape_cache_event reshape_cache_event = attn_metadata.reshape_cache_event
if self.pd_head_ratio != 1: if self.pd_head_ratio != 1:
assert self.resharding_stream is not None
with npu_stream_switch(self.resharding_stream):
reshape_cache_event.wait()
rearrange_block_ids = sorted({
block_id
for request in connector_metadata.requests.values()
for block_id in request.local_block_ids
})
def sort_kv_cache(input_kv: list[list[int]]): keys = kv_layer[0][rearrange_block_ids].clone()
return torch.cat([ values = kv_layer[1][rearrange_block_ids].clone()
torch.chunk(tensor, self.pd_head_ratio, dim=0)[x] # sort kv caches for each block
for x in range(self.pd_head_ratio) keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
for tensor in input_kv *keys.shape[2:]).transpose(
]) 0, 1).reshape_as(keys)
values = values.view(values.size(0), self.pd_head_ratio,
total_block_ids = [ -1, *values.shape[2:]).transpose(
request.local_block_ids 0, 1).reshape_as(values)
for request in connector_metadata.requests.values() # reshard kv cache
] keys = keys.reshape(-1, *kv_layer[0].shape[2:])
keys = [ values = values.reshape(-1, *kv_layer[1].shape[2:])
kv_layer[0][block_ids].reshape( (keys, values) = kv_alltoall_and_rearrange(
-1, *kv_layer[0].shape[2:]).clone() self.pd_head_ratio, keys, values)
for block_ids in total_block_ids
]
values = [
kv_layer[1][block_ids].reshape(
-1, *kv_layer[1].shape[2:]).clone()
for block_ids in total_block_ids
]
key_block_size = keys[0].size(0) // len(total_block_ids[0])
value_block_size = values[0].size(0) // len(total_block_ids[0])
keys = sort_kv_cache(keys) # [req1_key, req2_key]
values = sort_kv_cache(values)
(keys,
values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys,
values)
key_start_id = 0
value_start_id = 0
else: else:
key = None keys = None
value = None values = None
rearrange_block_ids = None
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
send_task = SendTask(wait_event=reshape_cache_event,
k_cache=keys,
v_cache=values,
layer_idx=self.current_layer,
rearrange_block_ids=rearrange_block_ids)
for req_id, req_meta in connector_metadata.requests.items(): for req_id, req_meta in connector_metadata.requests.items():
if self.pd_head_ratio != 1:
key_block_num = len(
req_meta.local_block_ids) * key_block_size
value_block_num = len(
req_meta.local_block_ids) * value_block_size
key = keys[key_start_id:key_start_id + key_block_num]
value = values[value_start_id:value_start_id +
value_block_num]
key_start_id += key_block_num
value_start_id += value_block_num
req_meta_update = self.update_decoder_info(req_id, req_meta) req_meta_update = self.update_decoder_info(req_id, req_meta)
logger.debug( logger.debug(
f"Add request {req_id} to kv send layer thread. {req_meta_update=}" f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
) )
assert self.kv_send_layer_thread is not None send_task.send_request[req_id] = req_meta_update
assert reshape_cache_event is not None
self.kv_send_layer_thread.send_queue.put( self.kv_send_layer_thread.send_queue.put(send_task)
(req_id, req_meta_update, self.current_layer, key, value,
reshape_cache_event))
self.current_layer += 1 self.current_layer += 1
def _get_remote_socket( def _get_remote_socket(
@@ -1106,6 +1146,14 @@ class MooncakeLayerwiseConnectorWorker:
logger.info( logger.info(
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
) )
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
ret = self.engine.batch_transfer_sync_write(
session_id, [self.kv_caches_base_addr[0]],
[agent_meta.kv_caches_base_addr[0]], 128)
if ret < 0:
logger.error(
f"Mooncake transfer failed to create link to device {session_id}"
)
req_meta_update.remote_te_rpc_port = self.remote_te_port[ req_meta_update.remote_te_rpc_port = self.remote_te_port[
req_meta_update.remote_engine_id][req_meta_update.remote_port] req_meta_update.remote_engine_id][req_meta_update.remote_port]
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[ req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[