From 05631064779462ee96da8ed4d113527d1eeb3324 Mon Sep 17 00:00:00 2001
From: lidenghui1110 <30521952+lidenghui1110@users.noreply.github.com>
Date: Mon, 13 Oct 2025 15:48:37 +0800
Subject: [PATCH] [Feature] mooncake connector support GQA transport (#2947)
### What this PR does / why we need it?
The previous implementation of the Mooncake connector only supported
scenarios where the Tensor Parallel sizes for the Prefill and Decode
phases were the same for MLA and GQA/MHA.
For heterogeneous TP scenarios, a single rank on a decode node needs to
pull the KV cache from multiple ranks on the prefill nodes and then
merge them (only prefill TP >= decode TP is supported for now). During this
merge, a transpose operation is required because the layouts of the KV
caches are different. To minimize transpose overhead, we use the
npu_paged_cache_load operation to extract the blocks corresponding to
the request from the KV cache. After performing the transpose, we use
_npu_reshape_and_cache to write the blocks back to their original
positions.
This process is illustrated in the diagram below.
b means block_size; this diagram illustrates the transposed KV cache layout
for one block. In the implementation, we transpose the KV cache layer by
layer for one request.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: chenxiao
Signed-off-by: zzy-ContiLearn <1831242919@qq.com>
Signed-off-by: zzhx1
Signed-off-by: Kurumi5210
Co-authored-by: zzy-ContiLearn <1831242919@qq.com>
Co-authored-by: chenxiao
Co-authored-by: chenxiao
Co-authored-by: zzhx1
---
vllm_ascend/distributed/mooncake_connector.py | 274 ++++++++++++++----
1 file changed, 222 insertions(+), 52 deletions(-)
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 6ecf8e7..dcdfdf6 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -17,6 +17,7 @@ import msgspec
import numpy as np
import numpy.typing as npt
import torch
+import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
@@ -30,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -127,8 +128,8 @@ class KVCacheSendingThread(threading.Thread):
def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
side_channel_host: str, side_channel_port: int,
- metadata: MooncakeAgentMetadata,
- ready_event: threading.Event):
+ metadata: MooncakeAgentMetadata, ready_event: threading.Event,
+ kv_caches: dict[str, Any]):
super().__init__(daemon=True, name="KVCacheSendingThread")
self.tp_rank = tp_rank
self.decode_tp_size = decode_tp_size
@@ -137,6 +138,7 @@ class KVCacheSendingThread(threading.Thread):
self.side_channel_port = side_channel_port
self.metadata = metadata
self.ready_event = ready_event
+ self.kv_caches = kv_caches
self.task_tracker = KVCacheTaskTracker()
@@ -220,7 +222,8 @@ class KVCacheRecvingThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine,
local_engine_id: str, local_handshake_port: int,
local_kv_caches_base_addr: list[int], block_len: list[int],
- ready_event: threading.Event):
+ ready_event: threading.Event, vllm_config: VllmConfig,
+ kv_caches: dict[str, Any]):
super().__init__(daemon=True, name="KVCacheRecvingThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
@@ -242,7 +245,6 @@ class KVCacheRecvingThread(threading.Thread):
self.use_sfa = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
- # TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.task_tracker = KVCacheTaskTracker()
@@ -256,9 +258,15 @@ class KVCacheRecvingThread(threading.Thread):
self.remote_poller = zmq.Poller() # type: ignore
self.timeout = 1.0 # seconds
+ self.vllm_config = vllm_config
+ self.model_config = self.vllm_config.model_config
+ self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads
+ self.kv_caches = kv_caches
+
def add_request(self, request_id: str, local_block_ids: list[int],
remote_block_ids: list[int], remote_engine_id: str,
- remote_host: str, remote_handshake_port: int):
+ remote_host: str, remote_handshake_port: int, offset: int,
+ num_need_pulls: int):
"""Add a new request to the queue for processing."""
logger.debug(f"Adding request {request_id} to the queue.")
self.request_queue.put({
@@ -268,6 +276,8 @@ class KVCacheRecvingThread(threading.Thread):
"remote_engine_id": remote_engine_id,
"remote_host": remote_host,
"remote_handshake_port": remote_handshake_port,
+ "offset": offset,
+ "num_need_pulls": num_need_pulls
})
def get_and_clear_finished_requests(self) -> set[str]:
@@ -296,6 +306,8 @@ class KVCacheRecvingThread(threading.Thread):
request_id = req_meta["request_id"]
remote_host = req_meta["remote_host"]
remote_handshake_port = req_meta["remote_handshake_port"]
+ offset = req_meta["offset"]
+ num_need_pulls = req_meta["num_need_pulls"]
try:
logger.debug(
@@ -307,12 +319,13 @@ class KVCacheRecvingThread(threading.Thread):
logger.error("Failed to transfer KV cache for request "
f"{request_id}: {e}")
finally:
- self.task_tracker.update_done_task_count(request_id)
# Always send the done signal to the remote host to ensure proper
# resource cleanup. Failing to do so may cause a memory leak on the
# remote host.
self._send_done_recv_signal(request_id, remote_host,
remote_handshake_port)
+ if offset == num_need_pulls - 1:
+ self.task_tracker.update_done_task_count(request_id)
self.request_queue.task_done()
def _transfer_kv_cache(self, req_meta: dict[str, Any]):
@@ -323,6 +336,8 @@ class KVCacheRecvingThread(threading.Thread):
remote_engine_id = req_meta["remote_engine_id"]
remote_host = req_meta["remote_host"]
remote_handshake_port = req_meta["remote_handshake_port"]
+ offset = req_meta["offset"]
+ self.num_need_pulls = req_meta["num_need_pulls"]
# Full prefix cache hit: do not need to read remote blocks, just notify
# P worker that we have the blocks we need.
@@ -331,23 +346,28 @@ class KVCacheRecvingThread(threading.Thread):
# Check if we have the remote metadata cached.
if remote_engine_id not in self.kv_caches_base_addr or \
- remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
+ remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
self._get_remote_metadata(remote_host, remote_handshake_port)
- grouped_remote_block_ids, grouped_local_block_ids = \
- group_concurrent_contiguous(remote_block_ids, local_block_ids)
+ if self.num_need_pulls == 1:
+ grouped_remote_block_ids, grouped_local_block_ids = \
+ group_concurrent_contiguous(remote_block_ids, local_block_ids)
+ else:
+ remote_block_ids = list(map(lambda x: [x], remote_block_ids))
+ local_block_ids = list(map(lambda x: [x], local_block_ids))
+ grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids
+ num_transfer_groups = len(grouped_remote_block_ids)
+
remote_kv_caches_base_addrs = \
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
local_kv_caches_base_addrs = \
self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port]
-
- req_start_time = time.perf_counter()
- num_transfer_groups = len(grouped_remote_block_ids)
- num_blocks = len(local_block_ids)
-
remote_transfer_port = self.remote_te_port[remote_engine_id][
remote_handshake_port]
+ num_blocks = len(local_block_ids)
session_id = f"{remote_host}:{remote_transfer_port}"
+
+ req_start_time = time.perf_counter()
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
@@ -357,14 +377,17 @@ class KVCacheRecvingThread(threading.Thread):
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
- for i, remote_block_id in enumerate(grouped_remote_block_ids):
- local_block_ids = grouped_local_block_ids[i]
- src = src_layer_base_addr + local_block_ids[0] * block_len
- dst = dst_layer_base_addr + remote_block_id[0] * block_len
- length = len(local_block_ids) * block_len
+ inner_block_len = block_len // self.num_need_pulls
+ for remote_block_id, local_block_id in zip(
+ grouped_remote_block_ids, grouped_local_block_ids):
+ src = src_layer_base_addr + local_block_id[
+ 0] * block_len + offset * inner_block_len
+ dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
+ length = inner_block_len * len(local_block_id)
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
+
ret = self.engine.batch_transfer_sync_read(session_id, src_list,
dst_list, length_list)
if ret < 0:
@@ -376,8 +399,99 @@ class KVCacheRecvingThread(threading.Thread):
req_transfer_elapsed = (req_end_time - req_start_time) * 1000
logger.info(
"KV cache transfer for request %s took %.2f ms (%d groups,"
- " %d blocks).", request_id, req_transfer_elapsed,
- num_transfer_groups, num_blocks)
+ " %d blocks). local_ip %s local_device_id %s remote_session_id %s",
+ request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
+ get_ip(), self.tp_rank, session_id)
+ if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1:
+ self._cat_kv_cache(grouped_local_block_ids)
+
+ def _cat_kv_cache(self, block_ids: list[list[int]]):
+ # Get necessary parameters
+ k_cache = list(self.kv_caches.values())[0][0]
+ kv_shape = k_cache.shape
+ dtype = k_cache.dtype
+ device = k_cache.device
+ head_dim = self.model_config.hf_config.head_dim
+ block_size = self.vllm_config.cache_config.block_size
+ num_kv_head = max(
+ self.model_config.hf_config.num_key_value_heads // self.tp_size, 1)
+
+ flat_block_ids = [item for sublist in block_ids for item in sublist]
+ block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
+ num_blocks = len(flat_block_ids)
+ block_len = num_blocks * block_size
+
+ # Create device tensors for copy operations
+ block_table = block_ids_tensor.view(1, -1).to(device=device)
+ block_len_tensor = torch.tensor([block_len],
+ dtype=torch.int32).to(device=device)
+ seq_start_tensor = torch.tensor([0],
+ dtype=torch.int32).to(device=device)
+
+ # Initialize buffers
+ k_buffer = torch.empty(block_len,
+ num_kv_head,
+ head_dim,
+ dtype=dtype,
+ device=device)
+ v_buffer = torch.empty(block_len,
+ num_kv_head,
+ head_dim,
+ dtype=dtype,
+ device=device)
+
+ # Create slot mapping for reshape operations
+ block_offsets = torch.arange(0, block_size, dtype=torch.int32)
+ slot_mapping = (block_offsets.reshape(
+ (1, block_size)) + block_ids_tensor.reshape(
+ (num_blocks, 1)) * block_size)
+ slot_mapping = slot_mapping.flatten().to(device=device)
+
+ # Process each layer in the KV cache
+ for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
+ if len(
+ k_cache_layer.shape
+ ) == 3: # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim]
+ k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1],
+ num_kv_head, head_dim)
+ v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1],
+ num_kv_head, head_dim)
+ # Load cache data into buffers
+ torch_npu.atb.npu_paged_cache_load(
+ k_cache_layer,
+ v_cache_layer,
+ block_table,
+ block_len_tensor,
+ seq_starts=seq_start_tensor,
+ key=k_buffer,
+ value=v_buffer,
+ )
+
+ # Transpose KV cache
+ k_buffer = self._transpose_kv_cache_between_head(
+ k_buffer, num_blocks, block_size, block_len, num_kv_head)
+ v_buffer = self._transpose_kv_cache_between_head(
+ v_buffer, num_blocks, block_size, block_len, num_kv_head)
+
+ # Reshape and cache the processed buffers
+ torch_npu._npu_reshape_and_cache(
+ key=k_buffer,
+ value=v_buffer,
+ key_cache=k_cache_layer,
+ value_cache=v_cache_layer,
+ slot_indices=slot_mapping,
+ )
+
+ # Clean up buffers
+ del k_buffer, v_buffer
+
+ def _transpose_kv_cache_between_head(self, buffer: torch.Tensor,
+ num_blocks: int, block_size: int,
+ block_len: int,
+ num_kv_head: int) -> torch.Tensor:
+ buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1)
+ buffer.transpose_(1, 2)
+ return buffer.contiguous().view(block_len, num_kv_head, -1)
def _get_remote_metadata(self, remote_host: str,
remote_handshake_port: int) -> None:
@@ -573,9 +687,11 @@ class MooncakeConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
+ init_ascend_config(vllm_config)
self.ascend_config = get_ascend_config()
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
+ self.local_ip = get_ip()
logger.info("Initializing Mooncake Scheduler %s", engine_id)
self.side_channel_host = get_ip()
@@ -716,6 +832,7 @@ class MooncakeConnectorScheduler:
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
+ last_token_id=request.output_token_ids[-1],
)
def get_finished_count(self) -> Optional[int]:
@@ -732,12 +849,23 @@ class MooncakeConnectorScheduler:
"decode", {})
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
-
+ num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
- return self._decode_tp_size
+ num_need_pulls = 1
else:
- # TODO support mha and gqa
- return None
+ num_p_block_heads = max(
+ 1, num_key_value_heads // self._prefill_tp_size)
+ num_d_block_heads = max(
+ 1, num_key_value_heads // self._decode_tp_size)
+ num_need_pulls = num_d_block_heads // num_p_block_heads
+ kv_role = self.vllm_config.kv_transfer_config.kv_role
+ logger.debug(
+ "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
+ kv_role, num_need_pulls, self._decode_tp_size)
+ if kv_role == 'kv_producer':
+ return num_need_pulls * self._decode_tp_size
+ else:
+ return self._decode_tp_size
class MooncakeConnectorWorker:
@@ -757,6 +885,7 @@ class MooncakeConnectorWorker:
# Metadata.
self.vllm_config = vllm_config
+ self.ascend_config = get_ascend_config()
self.engine_id = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
@@ -767,6 +896,7 @@ class MooncakeConnectorWorker:
self.side_channel_host = get_ip()
self.max_device_id = self.tp_size * self.dp_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
+ self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
# Handshake base port
self.side_channel_port = (
@@ -809,8 +939,17 @@ class MooncakeConnectorWorker:
self.kv_send_thread: Optional[KVCacheSendingThread] = None
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
+ # kv_transfer variables
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
+ if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+ self.num_need_pulls = 1
+ else:
+ num_d_block_heads = max(1,
+ self.num_key_value_heads // self.tp_size)
+ num_p_block_heads = max(
+ 1, self.num_key_value_heads // self._prefill_tp_size)
+ self.num_need_pulls = num_d_block_heads // num_p_block_heads
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
@@ -886,15 +1025,17 @@ class MooncakeConnectorWorker:
self.num_blocks, block_shape_norm, block_shape_pe,
block_shape_k)
else:
- # [num_block, block_size, num_head, hidden_dim]
+ # eager:[num_block, block_size, num_head, hidden_dim]
+ # torchair:[num_block, block_size, num_head*hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
- block_rank = 3 # [block_size, kv_heads, head_dim]
+ block_rank = len(
+ first_kv_cache.shape
+ ) - 1 # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
-
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
@@ -935,23 +1076,21 @@ class MooncakeConnectorWorker:
ready_event = threading.Event()
if self.kv_role == 'kv_producer':
- self.kv_send_thread = KVCacheSendingThread(self.tp_rank,
- self._decode_tp_size,
- self.engine_id,
- self.side_channel_host,
- self.side_channel_port,
- metadata, ready_event)
+ self.kv_send_thread = KVCacheSendingThread(
+ self.tp_rank, self._decode_tp_size, self.engine_id,
+ self.side_channel_host, self.side_channel_port, metadata,
+ ready_event, self.kv_caches)
self.kv_send_thread.start()
else:
self.kv_recv_thread = KVCacheRecvingThread(
self.tp_rank, self.tp_size, self.engine, self.engine_id,
self.handshake_port, kv_caches_base_addr, self.block_len,
- ready_event)
+ ready_event, self.vllm_config, self.kv_caches)
self.kv_recv_thread.start()
ready_event.wait()
def _register(self, ptr, length):
- logger.info(
+ logger.debug(
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
ret_value = self.engine.register_memory(ptr, length)
@@ -982,16 +1121,21 @@ class MooncakeConnectorWorker:
meta.remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
- remote_handshake_port = meta.remote_port + \
- self._get_remote_tp_rank(req_id)
- self.kv_recv_thread.add_request( # type: ignore[union-attr]
- request_id=req_id,
- local_block_ids=meta.local_block_ids,
- remote_block_ids=meta.remote_block_ids,
- remote_engine_id=meta.remote_engine_id,
- remote_host=meta.remote_host,
- remote_handshake_port=remote_handshake_port,
- )
+ choosen_rank_list = self._get_remote_tp_rank(req_id)
+ remote_handshake_port_list = [
+ x + meta.remote_port for x in choosen_rank_list
+ ]
+ for i in range(self.num_need_pulls):
+ assert self.kv_recv_thread is not None
+ self.kv_recv_thread.add_request(
+ request_id=req_id,
+ local_block_ids=meta.local_block_ids,
+ remote_block_ids=meta.remote_block_ids,
+ remote_engine_id=meta.remote_engine_id,
+ remote_host=meta.remote_host,
+ remote_handshake_port=remote_handshake_port_list[i],
+ offset=i,
+ num_need_pulls=self.num_need_pulls)
if self.kv_send_thread is not None:
for req_id, delay_start_time in metadata.requests_to_send.items():
@@ -999,17 +1143,43 @@ class MooncakeConnectorWorker:
self.kv_send_thread.add_delayed_request(
req_id, delay_start_time)
- def _get_remote_tp_rank(self, req_id: str) -> int:
+ def _get_remote_tp_rank(self, req_id: str) -> List[int]:
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
- def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]:
+ def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
if self._prefill_tp_size == self._decode_tp_size:
- return list(range(self._prefill_tp_size))
+ result = list(map(lambda x: [x], range(self._prefill_tp_size)))
+ return result
seed = string_to_int64_hash(req_id)
rand = random.Random(seed)
- sampled_nums = rand.sample(range(self._prefill_tp_size),
- self._decode_tp_size)
+ sampled_nums = []
+ ori_data = np.arange(self._prefill_tp_size)
+ # random split prefill tp list
+ if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+ # use deepseek mla, num_key_value_heads == 128, but consider as 1
+ if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+ num_kv_head = 1
+ else:
+ num_kv_head = self.num_key_value_heads
+ num_groups = len(ori_data) // num_kv_head
+ ori_data = ori_data.reshape(-1, num_groups)
+ rand_group_index = rand.sample(range(num_groups), \
+ max(self._decode_tp_size // num_kv_head, 1)) # random choose a group
+
+ choosen_group = ori_data[:, [rand_group_index]]
+ flattened = choosen_group.reshape(-1).tolist()
+ sampled_nums = [
+ flattened[i:i + self.num_need_pulls]
+ for i in range(0, len(flattened), self.num_need_pulls)
+ ]
+
+ # non-random split
+ else:
+ group_size = self._prefill_tp_size // self._decode_tp_size
+ for i in range(self._decode_tp_size):
+ slice = ori_data[i * group_size:(i + 1) * group_size]
+ sampled_nums.append(slice)
return sampled_nums