[Feature] mooncake connector support GQA transport (#2947)

### What this PR does / why we need it?
The previous implementation of the Mooncake connector only supported
scenarios where the Tensor Parallel sizes for the Prefill and Decode
phases were the same for MLA and GQA/MHA.

For heterogeneous TP scenarios, a single rank on a decode node needs to
pull the KV cache from multiple ranks on the prefill nodes and then
merge them (only prefill TP >= decode TP is supported for now). During this
merge, a transpose operation is required because the layouts of the KV
caches are different. To minimize transpose overhead, we use the
npu_paged_cache_load operation to extract the blocks corresponding to
the request from the KV cache. After performing the transpose, we use
_npu_reshape_and_cache to write the blocks back to their original
positions.

This process is illustrated in the diagram below.

`b` means block_size; this diagram illustrates the transposed KV cache layout
for one block. In the implementation, we transpose the KV cache layer by
layer for one request.

<img width="1464" height="916" alt="image"
src="https://github.com/user-attachments/assets/09d96a98-e41c-4733-9535-05544163081a"
/>

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: v0.11.0
---------

Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: zzy-ContiLearn <1831242919@qq.com>
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: zzy-ContiLearn <1831242919@qq.com>
Co-authored-by: chenxiao <cx02308786@antgroup.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
This commit is contained in:
lidenghui1110
2025-10-13 15:48:37 +08:00
committed by GitHub
parent 847d12a389
commit 0563106477

View File

@@ -17,6 +17,7 @@ import msgspec
import numpy as np
import numpy.typing as npt
import torch
import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
@@ -30,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -127,8 +128,8 @@ class KVCacheSendingThread(threading.Thread):
def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
side_channel_host: str, side_channel_port: int,
metadata: MooncakeAgentMetadata,
ready_event: threading.Event):
metadata: MooncakeAgentMetadata, ready_event: threading.Event,
kv_caches: dict[str, Any]):
super().__init__(daemon=True, name="KVCacheSendingThread")
self.tp_rank = tp_rank
self.decode_tp_size = decode_tp_size
@@ -137,6 +138,7 @@ class KVCacheSendingThread(threading.Thread):
self.side_channel_port = side_channel_port
self.metadata = metadata
self.ready_event = ready_event
self.kv_caches = kv_caches
self.task_tracker = KVCacheTaskTracker()
@@ -220,7 +222,8 @@ class KVCacheRecvingThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine,
local_engine_id: str, local_handshake_port: int,
local_kv_caches_base_addr: list[int], block_len: list[int],
ready_event: threading.Event):
ready_event: threading.Event, vllm_config: VllmConfig,
kv_caches: dict[str, Any]):
super().__init__(daemon=True, name="KVCacheRecvingThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
@@ -242,7 +245,6 @@ class KVCacheRecvingThread(threading.Thread):
self.use_sfa = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.task_tracker = KVCacheTaskTracker()
@@ -256,9 +258,15 @@ class KVCacheRecvingThread(threading.Thread):
self.remote_poller = zmq.Poller() # type: ignore
self.timeout = 1.0 # seconds
self.vllm_config = vllm_config
self.model_config = self.vllm_config.model_config
self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads
self.kv_caches = kv_caches
def add_request(self, request_id: str, local_block_ids: list[int],
remote_block_ids: list[int], remote_engine_id: str,
remote_host: str, remote_handshake_port: int):
remote_host: str, remote_handshake_port: int, offset: int,
num_need_pulls: int):
"""Add a new request to the queue for processing."""
logger.debug(f"Adding request {request_id} to the queue.")
self.request_queue.put({
@@ -268,6 +276,8 @@ class KVCacheRecvingThread(threading.Thread):
"remote_engine_id": remote_engine_id,
"remote_host": remote_host,
"remote_handshake_port": remote_handshake_port,
"offset": offset,
"num_need_pulls": num_need_pulls
})
def get_and_clear_finished_requests(self) -> set[str]:
@@ -296,6 +306,8 @@ class KVCacheRecvingThread(threading.Thread):
request_id = req_meta["request_id"]
remote_host = req_meta["remote_host"]
remote_handshake_port = req_meta["remote_handshake_port"]
offset = req_meta["offset"]
num_need_pulls = req_meta["num_need_pulls"]
try:
logger.debug(
@@ -307,12 +319,13 @@ class KVCacheRecvingThread(threading.Thread):
logger.error("Failed to transfer KV cache for request "
f"{request_id}: {e}")
finally:
self.task_tracker.update_done_task_count(request_id)
# Always send the done signal to the remote host to ensure proper
# resource cleanup. Failing to do so may cause a memory leak on the
# remote host.
self._send_done_recv_signal(request_id, remote_host,
remote_handshake_port)
if offset == num_need_pulls - 1:
self.task_tracker.update_done_task_count(request_id)
self.request_queue.task_done()
def _transfer_kv_cache(self, req_meta: dict[str, Any]):
@@ -323,6 +336,8 @@ class KVCacheRecvingThread(threading.Thread):
remote_engine_id = req_meta["remote_engine_id"]
remote_host = req_meta["remote_host"]
remote_handshake_port = req_meta["remote_handshake_port"]
offset = req_meta["offset"]
self.num_need_pulls = req_meta["num_need_pulls"]
# Full prefix cache hit: do not need to read remote blocks, just notify
# P worker that we have the blocks we need.
@@ -331,23 +346,28 @@ class KVCacheRecvingThread(threading.Thread):
# Check if we have the remote metadata cached.
if remote_engine_id not in self.kv_caches_base_addr or \
remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
self._get_remote_metadata(remote_host, remote_handshake_port)
grouped_remote_block_ids, grouped_local_block_ids = \
group_concurrent_contiguous(remote_block_ids, local_block_ids)
if self.num_need_pulls == 1:
grouped_remote_block_ids, grouped_local_block_ids = \
group_concurrent_contiguous(remote_block_ids, local_block_ids)
else:
remote_block_ids = list(map(lambda x: [x], remote_block_ids))
local_block_ids = list(map(lambda x: [x], local_block_ids))
grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids
num_transfer_groups = len(grouped_remote_block_ids)
remote_kv_caches_base_addrs = \
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
local_kv_caches_base_addrs = \
self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port]
req_start_time = time.perf_counter()
num_transfer_groups = len(grouped_remote_block_ids)
num_blocks = len(local_block_ids)
remote_transfer_port = self.remote_te_port[remote_engine_id][
remote_handshake_port]
num_blocks = len(local_block_ids)
session_id = f"{remote_host}:{remote_transfer_port}"
req_start_time = time.perf_counter()
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
@@ -357,14 +377,17 @@ class KVCacheRecvingThread(threading.Thread):
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
for i, remote_block_id in enumerate(grouped_remote_block_ids):
local_block_ids = grouped_local_block_ids[i]
src = src_layer_base_addr + local_block_ids[0] * block_len
dst = dst_layer_base_addr + remote_block_id[0] * block_len
length = len(local_block_ids) * block_len
inner_block_len = block_len // self.num_need_pulls
for remote_block_id, local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids):
src = src_layer_base_addr + local_block_id[
0] * block_len + offset * inner_block_len
dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
length = inner_block_len * len(local_block_id)
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
ret = self.engine.batch_transfer_sync_read(session_id, src_list,
dst_list, length_list)
if ret < 0:
@@ -376,8 +399,99 @@ class KVCacheRecvingThread(threading.Thread):
req_transfer_elapsed = (req_end_time - req_start_time) * 1000
logger.info(
"KV cache transfer for request %s took %.2f ms (%d groups,"
" %d blocks).", request_id, req_transfer_elapsed,
num_transfer_groups, num_blocks)
" %d blocks). local_ip %s local_device_id %s remote_session_id %s",
request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
get_ip(), self.tp_rank, session_id)
if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1:
self._cat_kv_cache(grouped_local_block_ids)
def _cat_kv_cache(self, block_ids: list[list[int]]):
    """Re-layout KV cache blocks pulled from multiple prefill ranks.

    In heterogeneous-TP scenarios a single decode rank pulls KV chunks
    from several prefill ranks into adjacent slices of the same blocks.
    After the last pull completes, this routine loads those blocks from
    the paged cache, transposes the pull-chunk and token axes so the head
    layout matches the decode side, and writes the blocks back to their
    original positions.

    Args:
        block_ids: local block ids for the request, grouped into runs of
            contiguous ids; flattened before use.
    """
    # Get necessary parameters from the first layer's key cache
    # (assumes all layers share shape/dtype/device — TODO confirm).
    k_cache = list(self.kv_caches.values())[0][0]
    kv_shape = k_cache.shape
    dtype = k_cache.dtype
    device = k_cache.device
    head_dim = self.model_config.hf_config.head_dim
    block_size = self.vllm_config.cache_config.block_size
    # Per-rank KV head count; at least 1 when heads are replicated
    # across more ranks than there are KV heads.
    num_kv_head = max(
        self.model_config.hf_config.num_key_value_heads // self.tp_size, 1)

    flat_block_ids = [item for sublist in block_ids for item in sublist]
    block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
    num_blocks = len(flat_block_ids)
    # Total number of token slots covered by this request's blocks.
    block_len = num_blocks * block_size

    # Create device tensors for the paged-cache load below.
    block_table = block_ids_tensor.view(1, -1).to(device=device)
    block_len_tensor = torch.tensor([block_len],
                                    dtype=torch.int32).to(device=device)
    seq_start_tensor = torch.tensor([0],
                                    dtype=torch.int32).to(device=device)

    # Initialize staging buffers in [tokens, heads, head_dim] layout;
    # reused for every layer.
    k_buffer = torch.empty(block_len,
                           num_kv_head,
                           head_dim,
                           dtype=dtype,
                           device=device)
    v_buffer = torch.empty(block_len,
                           num_kv_head,
                           head_dim,
                           dtype=dtype,
                           device=device)

    # Create slot mapping for the reshape-and-cache write-back: the
    # absolute slot index (block_id * block_size + in-block offset) for
    # every token position of this request.
    block_offsets = torch.arange(0, block_size, dtype=torch.int32)
    slot_mapping = (block_offsets.reshape(
        (1, block_size)) + block_ids_tensor.reshape(
            (num_blocks, 1)) * block_size)
    slot_mapping = slot_mapping.flatten().to(device=device)

    # Process each layer in the KV cache.
    for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
        if len(
                k_cache_layer.shape
        ) == 3:  # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim]
            # Expand the fused head dim so both layouts are handled
            # uniformly below.
            k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1],
                                               num_kv_head, head_dim)
            v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1],
                                               num_kv_head, head_dim)
        # Load this request's blocks from the paged cache into the
        # staging buffers.
        torch_npu.atb.npu_paged_cache_load(
            k_cache_layer,
            v_cache_layer,
            block_table,
            block_len_tensor,
            seq_starts=seq_start_tensor,
            key=k_buffer,
            value=v_buffer,
        )
        # Transpose so the head chunks pulled from different prefill
        # ranks become contiguous per token.
        k_buffer = self._transpose_kv_cache_between_head(
            k_buffer, num_blocks, block_size, block_len, num_kv_head)
        v_buffer = self._transpose_kv_cache_between_head(
            v_buffer, num_blocks, block_size, block_len, num_kv_head)
        # Write the re-laid-out tokens back to their original slots.
        torch_npu._npu_reshape_and_cache(
            key=k_buffer,
            value=v_buffer,
            key_cache=k_cache_layer,
            value_cache=v_cache_layer,
            slot_indices=slot_mapping,
        )
    # Clean up buffers.
    del k_buffer, v_buffer
def _transpose_kv_cache_between_head(self, buffer: torch.Tensor,
num_blocks: int, block_size: int,
block_len: int,
num_kv_head: int) -> torch.Tensor:
buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1)
buffer.transpose_(1, 2)
return buffer.contiguous().view(block_len, num_kv_head, -1)
def _get_remote_metadata(self, remote_host: str,
remote_handshake_port: int) -> None:
@@ -573,9 +687,11 @@ class MooncakeConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
init_ascend_config(vllm_config)
self.ascend_config = get_ascend_config()
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
self.local_ip = get_ip()
logger.info("Initializing Mooncake Scheduler %s", engine_id)
self.side_channel_host = get_ip()
@@ -716,6 +832,7 @@ class MooncakeConnectorScheduler:
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
last_token_id=request.output_token_ids[-1],
)
def get_finished_count(self) -> Optional[int]:
@@ -732,12 +849,23 @@ class MooncakeConnectorScheduler:
"decode", {})
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
return self._decode_tp_size
num_need_pulls = 1
else:
# TODO support mha and gqa
return None
num_p_block_heads = max(
1, num_key_value_heads // self._prefill_tp_size)
num_d_block_heads = max(
1, num_key_value_heads // self._decode_tp_size)
num_need_pulls = num_d_block_heads // num_p_block_heads
kv_role = self.vllm_config.kv_transfer_config.kv_role
logger.debug(
"get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
kv_role, num_need_pulls, self._decode_tp_size)
if kv_role == 'kv_producer':
return num_need_pulls * self._decode_tp_size
else:
return self._decode_tp_size
class MooncakeConnectorWorker:
@@ -757,6 +885,7 @@ class MooncakeConnectorWorker:
# Metadata.
self.vllm_config = vllm_config
self.ascend_config = get_ascend_config()
self.engine_id = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
@@ -767,6 +896,7 @@ class MooncakeConnectorWorker:
self.side_channel_host = get_ip()
self.max_device_id = self.tp_size * self.dp_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
# Handshake base port
self.side_channel_port = (
@@ -809,8 +939,17 @@ class MooncakeConnectorWorker:
self.kv_send_thread: Optional[KVCacheSendingThread] = None
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
# kv_transfer variables
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
self.num_need_pulls = 1
else:
num_d_block_heads = max(1,
self.num_key_value_heads // self.tp_size)
num_p_block_heads = max(
1, self.num_key_value_heads // self._prefill_tp_size)
self.num_need_pulls = num_d_block_heads // num_p_block_heads
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
@@ -886,15 +1025,17 @@ class MooncakeConnectorWorker:
self.num_blocks, block_shape_norm, block_shape_pe,
block_shape_k)
else:
# [num_block, block_size, num_head, hidden_dim]
# eager:[num_block, block_size, num_head, hidden_dim]
# torchair:[num_block, block_size, num_head*hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_rank = len(
first_kv_cache.shape
) - 1 # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
@@ -935,23 +1076,21 @@ class MooncakeConnectorWorker:
ready_event = threading.Event()
if self.kv_role == 'kv_producer':
self.kv_send_thread = KVCacheSendingThread(self.tp_rank,
self._decode_tp_size,
self.engine_id,
self.side_channel_host,
self.side_channel_port,
metadata, ready_event)
self.kv_send_thread = KVCacheSendingThread(
self.tp_rank, self._decode_tp_size, self.engine_id,
self.side_channel_host, self.side_channel_port, metadata,
ready_event, self.kv_caches)
self.kv_send_thread.start()
else:
self.kv_recv_thread = KVCacheRecvingThread(
self.tp_rank, self.tp_size, self.engine, self.engine_id,
self.handshake_port, kv_caches_base_addr, self.block_len,
ready_event)
ready_event, self.vllm_config, self.kv_caches)
self.kv_recv_thread.start()
ready_event.wait()
def _register(self, ptr, length):
logger.info(
logger.debug(
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
ret_value = self.engine.register_memory(ptr, length)
@@ -982,16 +1121,21 @@ class MooncakeConnectorWorker:
meta.remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
remote_handshake_port = meta.remote_port + \
self._get_remote_tp_rank(req_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr]
request_id=req_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_handshake_port=remote_handshake_port,
)
choosen_rank_list = self._get_remote_tp_rank(req_id)
remote_handshake_port_list = [
x + meta.remote_port for x in choosen_rank_list
]
for i in range(self.num_need_pulls):
assert self.kv_recv_thread is not None
self.kv_recv_thread.add_request(
request_id=req_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_handshake_port=remote_handshake_port_list[i],
offset=i,
num_need_pulls=self.num_need_pulls)
if self.kv_send_thread is not None:
for req_id, delay_start_time in metadata.requests_to_send.items():
@@ -999,17 +1143,43 @@ class MooncakeConnectorWorker:
self.kv_send_thread.add_delayed_request(
req_id, delay_start_time)
def _get_remote_tp_rank(self, req_id: str) -> int:
def _get_remote_tp_rank(self, req_id: str) -> List[int]:
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]:
def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
if self._prefill_tp_size == self._decode_tp_size:
return list(range(self._prefill_tp_size))
result = list(map(lambda x: [x], range(self._prefill_tp_size)))
return result
seed = string_to_int64_hash(req_id)
rand = random.Random(seed)
sampled_nums = rand.sample(range(self._prefill_tp_size),
self._decode_tp_size)
sampled_nums = []
ori_data = np.arange(self._prefill_tp_size)
# random split prefill tp list
if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
# use deepseek mla, num_key_value_heads == 128, but consider as 1
if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
num_kv_head = 1
else:
num_kv_head = self.num_key_value_heads
num_groups = len(ori_data) // num_kv_head
ori_data = ori_data.reshape(-1, num_groups)
rand_group_index = rand.sample(range(num_groups), \
max(self._decode_tp_size // num_kv_head, 1)) # random choose a group
choosen_group = ori_data[:, [rand_group_index]]
flattened = choosen_group.reshape(-1).tolist()
sampled_nums = [
flattened[i:i + self.num_need_pulls]
for i in range(0, len(flattened), self.num_need_pulls)
]
# non-random split
else:
group_size = self._prefill_tp_size // self._decode_tp_size
for i in range(self._decode_tp_size):
slice = ori_data[i * group_size:(i + 1) * group_size]
sampled_nums.append(slice)
return sampled_nums