From 05631064779462ee96da8ed4d113527d1eeb3324 Mon Sep 17 00:00:00 2001
From: lidenghui1110 <30521952+lidenghui1110@users.noreply.github.com>
Date: Mon, 13 Oct 2025 15:48:37 +0800
Subject: [PATCH] [Feature] mooncake connector support GQA transport (#2947)

### What this PR does / why we need it?
The previous implementation of the Mooncake connector only supported scenarios where the Tensor Parallel sizes of the Prefill and Decode phases were the same, for both MLA and GQA/MHA. For heterogeneous TP scenarios, a single rank on a decode node needs to pull the KV cache from multiple ranks on the prefill nodes and then merge them (currently only prefill TP >= decode TP is supported). During this merge, a transpose operation is required because the layouts of the KV caches are different. To minimize transpose overhead, we use the npu_paged_cache_load operation to extract the blocks corresponding to the request from the KV cache. After performing the transpose, we use _npu_reshape_and_cache to write the blocks back to their original positions. This process is illustrated in the diagram below: b denotes block_size, and the diagram shows the KV cache layout transpose for one block. In the implementation, we transpose the KV cache layer by layer for each request.

[Figure: KV cache layout transpose for one block]

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- vLLM version: v0.11.0

---------

Signed-off-by: chenxiao
Signed-off-by: zzy-ContiLearn <1831242919@qq.com>
Signed-off-by: zzhx1
Signed-off-by: Kurumi5210
Co-authored-by: zzy-ContiLearn <1831242919@qq.com>
Co-authored-by: chenxiao
Co-authored-by: zzhx1
---
 vllm_ascend/distributed/mooncake_connector.py | 274 ++++++++++++++----
 1 file changed, 222 insertions(+), 52 deletions(-)

diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 6ecf8e7..dcdfdf6 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -17,6 +17,7 @@ import msgspec import numpy as np import numpy.typing as npt import torch +import torch_npu import zmq from mooncake.engine import TransferEngine # type: ignore from vllm import envs @@ -30,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -127,8 +128,8 @@ class KVCacheSendingThread(threading.Thread): def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str, side_channel_host: str, side_channel_port: int, - metadata: MooncakeAgentMetadata, - ready_event: threading.Event): + metadata: MooncakeAgentMetadata, ready_event: threading.Event, + kv_caches: dict[str, Any]): super().__init__(daemon=True, name="KVCacheSendingThread") self.tp_rank = tp_rank self.decode_tp_size = decode_tp_size @@ -137,6 +138,7 @@ class KVCacheSendingThread(threading.Thread): self.side_channel_port = side_channel_port self.metadata = metadata self.ready_event = ready_event + self.kv_caches = kv_caches self.task_tracker = KVCacheTaskTracker() @@ -220,7 +222,8 @@ class KVCacheRecvingThread(threading.Thread): def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine, local_engine_id: str, local_handshake_port: int, local_kv_caches_base_addr: list[int], block_len: list[int], - ready_event: threading.Event): + ready_event: 
threading.Event, vllm_config: VllmConfig, + kv_caches: dict[str, Any]): super().__init__(daemon=True, name="KVCacheRecvingThread") self.tp_rank = tp_rank self.tp_size = tp_size @@ -242,7 +245,6 @@ class KVCacheRecvingThread(threading.Thread): self.use_sfa = len(block_len) == 3 self.request_queue: queue.Queue[Any] = queue.Queue() - # TODO(jianzs): make this configurable self.executor = ThreadPoolExecutor(max_workers=32) self.task_tracker = KVCacheTaskTracker() @@ -256,9 +258,15 @@ class KVCacheRecvingThread(threading.Thread): self.remote_poller = zmq.Poller() # type: ignore self.timeout = 1.0 # seconds + self.vllm_config = vllm_config + self.model_config = self.vllm_config.model_config + self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads + self.kv_caches = kv_caches + def add_request(self, request_id: str, local_block_ids: list[int], remote_block_ids: list[int], remote_engine_id: str, - remote_host: str, remote_handshake_port: int): + remote_host: str, remote_handshake_port: int, offset: int, + num_need_pulls: int): """Add a new request to the queue for processing.""" logger.debug(f"Adding request {request_id} to the queue.") self.request_queue.put({ @@ -268,6 +276,8 @@ class KVCacheRecvingThread(threading.Thread): "remote_engine_id": remote_engine_id, "remote_host": remote_host, "remote_handshake_port": remote_handshake_port, + "offset": offset, + "num_need_pulls": num_need_pulls }) def get_and_clear_finished_requests(self) -> set[str]: @@ -296,6 +306,8 @@ class KVCacheRecvingThread(threading.Thread): request_id = req_meta["request_id"] remote_host = req_meta["remote_host"] remote_handshake_port = req_meta["remote_handshake_port"] + offset = req_meta["offset"] + num_need_pulls = req_meta["num_need_pulls"] try: logger.debug( @@ -307,12 +319,13 @@ class KVCacheRecvingThread(threading.Thread): logger.error("Failed to transfer KV cache for request " f"{request_id}: {e}") finally: - self.task_tracker.update_done_task_count(request_id) # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the # remote host. self._send_done_recv_signal(request_id, remote_host, remote_handshake_port) + if offset == num_need_pulls - 1: + self.task_tracker.update_done_task_count(request_id) self.request_queue.task_done() def _transfer_kv_cache(self, req_meta: dict[str, Any]): @@ -323,6 +336,8 @@ class KVCacheRecvingThread(threading.Thread): remote_engine_id = req_meta["remote_engine_id"] remote_host = req_meta["remote_host"] remote_handshake_port = req_meta["remote_handshake_port"] + offset = req_meta["offset"] + self.num_need_pulls = req_meta["num_need_pulls"] # Full prefix cache hit: do not need to read remote blocks, just notify # P worker that we have the blocks we need. @@ -331,23 +346,28 @@ class KVCacheRecvingThread(threading.Thread): # Check if we have the remote metadata cached. 
if remote_engine_id not in self.kv_caches_base_addr or \ - remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]: + remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]: self._get_remote_metadata(remote_host, remote_handshake_port) - grouped_remote_block_ids, grouped_local_block_ids = \ - group_concurrent_contiguous(remote_block_ids, local_block_ids) + if self.num_need_pulls == 1: + grouped_remote_block_ids, grouped_local_block_ids = \ + group_concurrent_contiguous(remote_block_ids, local_block_ids) + else: + remote_block_ids = list(map(lambda x: [x], remote_block_ids)) + local_block_ids = list(map(lambda x: [x], local_block_ids)) + grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids + num_transfer_groups = len(grouped_remote_block_ids) + remote_kv_caches_base_addrs = \ self.kv_caches_base_addr[remote_engine_id][remote_handshake_port] local_kv_caches_base_addrs = \ self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port] - - req_start_time = time.perf_counter() - num_transfer_groups = len(grouped_remote_block_ids) - num_blocks = len(local_block_ids) - remote_transfer_port = self.remote_te_port[remote_engine_id][ remote_handshake_port] + num_blocks = len(local_block_ids) session_id = f"{remote_host}:{remote_transfer_port}" + + req_start_time = time.perf_counter() src_list, dst_list, length_list = [], [], [] for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)): @@ -357,14 +377,17 @@ class KVCacheRecvingThread(threading.Thread): block_len = (self.block_len[k % 3]) else: block_len = (self.block_len[0]) - for i, remote_block_id in enumerate(grouped_remote_block_ids): - local_block_ids = grouped_local_block_ids[i] - src = src_layer_base_addr + local_block_ids[0] * block_len - dst = dst_layer_base_addr + remote_block_id[0] * block_len - length = len(local_block_ids) * block_len + inner_block_len = block_len // self.num_need_pulls + for remote_block_id, local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids): + src = src_layer_base_addr + local_block_id[ + 0] * block_len + offset * inner_block_len + dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len + length = inner_block_len * len(local_block_id) src_list.append(src) dst_list.append(dst) length_list.append(length) + ret = self.engine.batch_transfer_sync_read(session_id, src_list, dst_list, length_list) if ret < 0: @@ -376,8 +399,99 @@ class KVCacheRecvingThread(threading.Thread): req_transfer_elapsed = (req_end_time - req_start_time) * 1000 logger.info( "KV cache transfer for request %s took %.2f ms (%d groups," - " %d blocks).", request_id, req_transfer_elapsed, - num_transfer_groups, num_blocks) + " %d blocks). 
local_ip %s local_device_id %s remote_session_id %s", + request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, + get_ip(), self.tp_rank, session_id) + if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1: + self._cat_kv_cache(grouped_local_block_ids) + + def _cat_kv_cache(self, block_ids: list[list[int]]): + # Get necessary parameters + k_cache = list(self.kv_caches.values())[0][0] + kv_shape = k_cache.shape + dtype = k_cache.dtype + device = k_cache.device + head_dim = self.model_config.hf_config.head_dim + block_size = self.vllm_config.cache_config.block_size + num_kv_head = max( + self.model_config.hf_config.num_key_value_heads // self.tp_size, 1) + + flat_block_ids = [item for sublist in block_ids for item in sublist] + block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32) + num_blocks = len(flat_block_ids) + block_len = num_blocks * block_size + + # Create device tensors for copy operations + block_table = block_ids_tensor.view(1, -1).to(device=device) + block_len_tensor = torch.tensor([block_len], + dtype=torch.int32).to(device=device) + seq_start_tensor = torch.tensor([0], + dtype=torch.int32).to(device=device) + + # Initialize buffers + k_buffer = torch.empty(block_len, + num_kv_head, + head_dim, + dtype=dtype, + device=device) + v_buffer = torch.empty(block_len, + num_kv_head, + head_dim, + dtype=dtype, + device=device) + + # Create slot mapping for reshape operations + block_offsets = torch.arange(0, block_size, dtype=torch.int32) + slot_mapping = (block_offsets.reshape( + (1, block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * block_size) + slot_mapping = slot_mapping.flatten().to(device=device) + + # Process each layer in the KV cache + for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): + if len( + k_cache_layer.shape + ) == 3: # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim] + k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1], + num_kv_head, head_dim) + v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1], + num_kv_head, head_dim) + # Load cache data into buffers + torch_npu.atb.npu_paged_cache_load( + k_cache_layer, + v_cache_layer, + block_table, + block_len_tensor, + seq_starts=seq_start_tensor, + key=k_buffer, + value=v_buffer, + ) + + # Transpose KV cache + k_buffer = self._transpose_kv_cache_between_head( + k_buffer, num_blocks, block_size, block_len, num_kv_head) + v_buffer = self._transpose_kv_cache_between_head( + v_buffer, num_blocks, block_size, block_len, num_kv_head) + + # Reshape and cache the processed buffers + torch_npu._npu_reshape_and_cache( + key=k_buffer, + value=v_buffer, + key_cache=k_cache_layer, + value_cache=v_cache_layer, + slot_indices=slot_mapping, + ) + + # Clean up buffers + del k_buffer, v_buffer + + def _transpose_kv_cache_between_head(self, buffer: torch.Tensor, + num_blocks: int, block_size: int, + block_len: int, + num_kv_head: int) -> torch.Tensor: + buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1) + buffer.transpose_(1, 2) + return buffer.contiguous().view(block_len, num_kv_head, -1) def _get_remote_metadata(self, remote_host: str, remote_handshake_port: int) -> None: @@ -573,9 +687,11 @@ class MooncakeConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config + init_ascend_config(vllm_config) self.ascend_config = get_ascend_config() self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id + self.local_ip = get_ip() 
logger.info("Initializing Mooncake Scheduler %s", engine_id) self.side_channel_host = get_ip() @@ -716,6 +832,7 @@ class MooncakeConnectorScheduler: remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, + last_token_id=request.output_token_ids[-1], ) def get_finished_count(self) -> Optional[int]: @@ -732,12 +849,23 @@ class MooncakeConnectorScheduler: "decode", {}) assert "tp_size" in decode_parallel_config.keys() self._decode_tp_size = decode_parallel_config["tp_size"] - + num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: - return self._decode_tp_size + num_need_pulls = 1 else: - # TODO support mha and gqa - return None + num_p_block_heads = max( + 1, num_key_value_heads // self._prefill_tp_size) + num_d_block_heads = max( + 1, num_key_value_heads // self._decode_tp_size) + num_need_pulls = num_d_block_heads // num_p_block_heads + kv_role = self.vllm_config.kv_transfer_config.kv_role + logger.debug( + "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d", + kv_role, num_need_pulls, self._decode_tp_size) + if kv_role == 'kv_producer': + return num_need_pulls * self._decode_tp_size + else: + return self._decode_tp_size class MooncakeConnectorWorker: @@ -757,6 +885,7 @@ class MooncakeConnectorWorker: # Metadata. self.vllm_config = vllm_config + self.ascend_config = get_ascend_config() self.engine_id = engine_id self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = vllm_config.parallel_config.tensor_parallel_size @@ -767,6 +896,7 @@ class MooncakeConnectorWorker: self.side_channel_host = get_ip() self.max_device_id = self.tp_size * self.dp_size self.kv_role = vllm_config.kv_transfer_config.kv_role + self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads # Handshake base port self.side_channel_port = ( @@ -809,8 +939,17 @@ class MooncakeConnectorWorker: self.kv_send_thread: Optional[KVCacheSendingThread] = None self.kv_recv_thread: Optional[KVCacheRecvingThread] = None + # kv_transfer variables self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size + if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa: + self.num_need_pulls = 1 + else: + num_d_block_heads = max(1, + self.num_key_value_heads // self.tp_size) + num_p_block_heads = max( + 1, self.num_key_value_heads // self._prefill_tp_size) + self.num_need_pulls = num_d_block_heads // num_p_block_heads def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config @@ -886,15 +1025,17 @@ class MooncakeConnectorWorker: self.num_blocks, block_shape_norm, block_shape_pe, block_shape_k) else: - # [num_block, block_size, num_head, hidden_dim] + # eager:[num_block, block_size, num_head, hidden_dim] + # torchair:[num_block, block_size, num_head*hidden_dim] self.num_blocks = first_kv_cache.shape[0] kv_elem_size = first_kv_cache.element_size() - block_rank = 3 # [block_size, kv_heads, head_dim] + block_rank = len( + first_kv_cache.shape + ) - 1 # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim] block_shape = first_kv_cache.shape[-block_rank:] self.block_len = [kv_elem_size * math.prod(block_shape)] logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) - logger.info( "Registering KV_Caches. 
use_mla: %s, use_sfa: %s, shape %s", self.use_mla, self.use_sfa, first_kv_cache.shape) @@ -935,23 +1076,21 @@ class MooncakeConnectorWorker: ready_event = threading.Event() if self.kv_role == 'kv_producer': - self.kv_send_thread = KVCacheSendingThread(self.tp_rank, - self._decode_tp_size, - self.engine_id, - self.side_channel_host, - self.side_channel_port, - metadata, ready_event) + self.kv_send_thread = KVCacheSendingThread( + self.tp_rank, self._decode_tp_size, self.engine_id, + self.side_channel_host, self.side_channel_port, metadata, + ready_event, self.kv_caches) self.kv_send_thread.start() else: self.kv_recv_thread = KVCacheRecvingThread( self.tp_rank, self.tp_size, self.engine, self.engine_id, self.handshake_port, kv_caches_base_addr, self.block_len, - ready_event) + ready_event, self.vllm_config, self.kv_caches) self.kv_recv_thread.start() ready_event.wait() def _register(self, ptr, length): - logger.info( + logger.debug( "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " "block_lens=%s", ptr, length, self.num_blocks, self.block_len) ret_value = self.engine.register_memory(ptr, length) @@ -982,16 +1121,21 @@ class MooncakeConnectorWorker: meta.remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - remote_handshake_port = meta.remote_port + \ - self._get_remote_tp_rank(req_id) - self.kv_recv_thread.add_request( # type: ignore[union-attr] - request_id=req_id, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - remote_engine_id=meta.remote_engine_id, - remote_host=meta.remote_host, - remote_handshake_port=remote_handshake_port, - ) + choosen_rank_list = self._get_remote_tp_rank(req_id) + remote_handshake_port_list = [ + x + meta.remote_port for x in choosen_rank_list + ] + for i in range(self.num_need_pulls): + assert self.kv_recv_thread is not None + self.kv_recv_thread.add_request( + request_id=req_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_engine_id=meta.remote_engine_id, + remote_host=meta.remote_host, + remote_handshake_port=remote_handshake_port_list[i], + offset=i, + num_need_pulls=self.num_need_pulls) if self.kv_send_thread is not None: for req_id, delay_start_time in metadata.requests_to_send.items(): @@ -999,17 +1143,43 @@ class MooncakeConnectorWorker: self.kv_send_thread.add_delayed_request( req_id, delay_start_time) - def _get_remote_tp_rank(self, req_id: str) -> int: + def _get_remote_tp_rank(self, req_id: str) -> List[int]: return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] - def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]: + def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]: if self._prefill_tp_size == self._decode_tp_size: - return list(range(self._prefill_tp_size)) + result = list(map(lambda x: [x], range(self._prefill_tp_size))) + return result seed = string_to_int64_hash(req_id) rand = random.Random(seed) - sampled_nums = rand.sample(range(self._prefill_tp_size), - self._decode_tp_size) + sampled_nums = [] + ori_data = np.arange(self._prefill_tp_size) + # random split prefill tp list + if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa: + # use deepseek mla, num_key_value_heads == 128, but consider as 1 + if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa: + num_kv_head = 1 + else: + num_kv_head = self.num_key_value_heads + num_groups = len(ori_data) // num_kv_head + ori_data = 
ori_data.reshape(-1, num_groups) + rand_group_index = rand.sample(range(num_groups), \ + max(self._decode_tp_size // num_kv_head, 1)) # random choose a group + + choosen_group = ori_data[:, [rand_group_index]] + flattened = choosen_group.reshape(-1).tolist() + sampled_nums = [ + flattened[i:i + self.num_need_pulls] + for i in range(0, len(flattened), self.num_need_pulls) + ] + + # non-random split + else: + group_size = self._prefill_tp_size // self._decode_tp_size + for i in range(self._decode_tp_size): + slice = ori_data[i * group_size:(i + 1) * group_size] + sampled_nums.append(slice) return sampled_nums
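
To make the head-count arithmetic in the patch concrete, here is a minimal, self-contained sketch of how the number of pulls per decode rank falls out of the GQA head split. The helper name and the example TP sizes are illustrative assumptions, not part of the patch.

```python
# Hypothetical helper mirroring the num_need_pulls computation in
# get_finished_count() / MooncakeConnectorWorker; names and sizes are assumed.
def compute_num_need_pulls(num_key_value_heads: int, prefill_tp: int,
                           decode_tp: int) -> int:
    # KV heads owned by one prefill rank and by one decode rank (at least 1,
    # since TP may exceed the number of KV heads).
    num_p_block_heads = max(1, num_key_value_heads // prefill_tp)
    num_d_block_heads = max(1, num_key_value_heads // decode_tp)
    # One decode rank has to gather this many prefill shards to cover its heads.
    return num_d_block_heads // num_p_block_heads


if __name__ == "__main__":
    # Example: 8 GQA KV heads, prefill TP = 8, decode TP = 2.
    pulls = compute_num_need_pulls(num_key_value_heads=8, prefill_tp=8,
                                   decode_tp=2)
    print(pulls)  # 4 -> each decode rank merges KV shards from 4 prefill ranks
    # The producer side then waits for pulls * decode_tp done-signals per
    # request, while the consumer side still reports decode_tp finished workers.
    print(pulls * 2, 2)  # 8 2
```

For MLA/SFA models the head split does not apply, so num_need_pulls stays 1 and the previous single-pull path is used unchanged.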
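The per-pull address arithmetic in _transfer_kv_cache can also be read in isolation: each prefill shard is inner_block_len = block_len // num_need_pulls bytes and lands at a shard-sized offset inside the decode-side block. Below is a sketch with made-up base addresses and sizes (all numeric values are assumptions for illustration).

```python
# Illustrative only: recomputes the local/remote addresses used for one pull.
def pull_addresses(local_base: int, remote_base: int, local_block_id: int,
                   remote_block_id: int, block_len: int, offset: int,
                   num_need_pulls: int) -> tuple[int, int, int]:
    inner_block_len = block_len // num_need_pulls
    # Local side: full-size decode block, shifted by the shard offset.
    src = local_base + local_block_id * block_len + offset * inner_block_len
    # Remote side: the prefill rank only holds shard-sized blocks.
    dst = remote_base + remote_block_id * inner_block_len
    return src, dst, inner_block_len


if __name__ == "__main__":
    # Assumed sizes: block_size=128, 4 KV heads per decode rank, head_dim=128,
    # fp16 -> 128 * 4 * 128 * 2 = 131072 bytes per decode block, 4 pulls.
    src, dst, length = pull_addresses(local_base=0x1000_0000,
                                      remote_base=0x2000_0000,
                                      local_block_id=5, remote_block_id=9,
                                      block_len=131072, offset=2,
                                      num_need_pulls=4)
    print(hex(src), hex(dst), length)  # 0x100b0000 0x20048000 32768
```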
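Finally, the layout fix-up that _cat_kv_cache performs after the last pull can be reproduced on CPU with plain torch: after the raw copies, each decode block stores its prefill shards back-to-back (shard-major), while the attention kernels expect a tokens-major [block_size, num_kv_heads, head_dim] layout. The sketch below uses small assumed sizes and ordinary tensor ops instead of the NPU kernels (npu_paged_cache_load / _npu_reshape_and_cache), so it only demonstrates the index math, not the actual device path.

```python
import torch

# Assumed toy sizes: a decode rank owning 4 KV heads receives them as
# num_need_pulls=2 shards of 2 heads each from two prefill ranks.
num_blocks, block_size, head_dim = 3, 16, 8
num_need_pulls, heads_per_pull = 2, 2
num_kv_head = num_need_pulls * heads_per_pull  # KV heads per decode rank

# Memory layout of the decode blocks right after the raw pulls: each prefill
# shard occupies a contiguous inner_block_len slice, so the shard axis sits
# outside the token axis.
received = torch.arange(
    num_blocks * num_need_pulls * block_size * heads_per_pull * head_dim,
    dtype=torch.float32,
).view(num_blocks, num_need_pulls, block_size, heads_per_pull * head_dim)

# The same swap _transpose_kv_cache_between_head does: exchange the shard and
# token axes, then flatten back to [tokens, kv_heads, head_dim].
merged = received.transpose(1, 2).contiguous().view(
    num_blocks * block_size, num_kv_head, head_dim)

# Sanity check: token t of block b now holds shard 0's heads followed by
# shard 1's heads for that same token.
b, t = 1, 5
expected = torch.cat([received[b, p, t] for p in range(num_need_pulls)]).view(
    num_kv_head, head_dim)
assert torch.equal(merged[b * block_size + t], expected)
print(merged.shape)  # torch.Size([48, 4, 8])
```

In the connector itself this swap is done per layer on the NPU: npu_paged_cache_load gathers the request's blocks into contiguous key/value buffers, the buffers are transposed as above, and _npu_reshape_and_cache scatters the result back to the original slots.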