diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 6ecf8e7..dcdfdf6 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -17,6 +17,7 @@ import msgspec
 import numpy as np
 import numpy.typing as npt
 import torch
+import torch_npu
 import zmq
 from mooncake.engine import TransferEngine  # type: ignore
 from vllm import envs
@@ -30,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -127,8 +128,8 @@ class KVCacheSendingThread(threading.Thread):
     def __init__(self, tp_rank: int, decode_tp_size: int,
                  local_engine_id: str, side_channel_host: str,
                  side_channel_port: int,
-                 metadata: MooncakeAgentMetadata,
-                 ready_event: threading.Event):
+                 metadata: MooncakeAgentMetadata, ready_event: threading.Event,
+                 kv_caches: dict[str, Any]):
         super().__init__(daemon=True, name="KVCacheSendingThread")
         self.tp_rank = tp_rank
         self.decode_tp_size = decode_tp_size
@@ -137,6 +138,7 @@ class KVCacheSendingThread(threading.Thread):
         self.side_channel_port = side_channel_port
         self.metadata = metadata
         self.ready_event = ready_event
+        self.kv_caches = kv_caches
 
         self.task_tracker = KVCacheTaskTracker()
 
@@ -220,7 +222,8 @@ class KVCacheRecvingThread(threading.Thread):
     def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine,
                  local_engine_id: str, local_handshake_port: int,
                  local_kv_caches_base_addr: list[int], block_len: list[int],
-                 ready_event: threading.Event):
+                 ready_event: threading.Event, vllm_config: VllmConfig,
+                 kv_caches: dict[str, Any]):
         super().__init__(daemon=True, name="KVCacheRecvingThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -242,7 +245,6 @@ class KVCacheRecvingThread(threading.Thread):
         self.use_sfa = len(block_len) == 3
 
         self.request_queue: queue.Queue[Any] = queue.Queue()
-        # TODO(jianzs): make this configurable
         self.executor = ThreadPoolExecutor(max_workers=32)
 
         self.task_tracker = KVCacheTaskTracker()
@@ -256,9 +258,15 @@ class KVCacheRecvingThread(threading.Thread):
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
 
+        self.vllm_config = vllm_config
+        self.model_config = self.vllm_config.model_config
+        self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads
+        self.kv_caches = kv_caches
+
     def add_request(self, request_id: str, local_block_ids: list[int],
                     remote_block_ids: list[int], remote_engine_id: str,
-                    remote_host: str, remote_handshake_port: int):
+                    remote_host: str, remote_handshake_port: int, offset: int,
+                    num_need_pulls: int):
         """Add a new request to the queue for processing."""
         logger.debug(f"Adding request {request_id} to the queue.")
         self.request_queue.put({
@@ -268,6 +276,8 @@ class KVCacheRecvingThread(threading.Thread):
             "remote_engine_id": remote_engine_id,
             "remote_host": remote_host,
             "remote_handshake_port": remote_handshake_port,
+            "offset": offset,
+            "num_need_pulls": num_need_pulls
         })
 
     def get_and_clear_finished_requests(self) -> set[str]:
@@ -296,6 +306,8 @@ class KVCacheRecvingThread(threading.Thread):
         request_id = req_meta["request_id"]
         remote_host = req_meta["remote_host"]
         remote_handshake_port = req_meta["remote_handshake_port"]
+        offset = req_meta["offset"]
+        num_need_pulls = req_meta["num_need_pulls"]
 
         try:
             logger.debug(
@@ -307,12 +319,13 @@ class KVCacheRecvingThread(threading.Thread):
             logger.error("Failed to transfer KV cache for request "
                          f"{request_id}: {e}")
         finally:
-            self.task_tracker.update_done_task_count(request_id)
             # Always send the done signal to the remote host to ensure proper
             # resource cleanup. Failing to do so may cause a memory leak on the
             # remote host.
             self._send_done_recv_signal(request_id, remote_host,
                                         remote_handshake_port)
+            if offset == num_need_pulls - 1:
+                self.task_tracker.update_done_task_count(request_id)
             self.request_queue.task_done()
 
     def _transfer_kv_cache(self, req_meta: dict[str, Any]):
@@ -323,6 +336,8 @@ class KVCacheRecvingThread(threading.Thread):
         remote_engine_id = req_meta["remote_engine_id"]
         remote_host = req_meta["remote_host"]
         remote_handshake_port = req_meta["remote_handshake_port"]
+        offset = req_meta["offset"]
+        self.num_need_pulls = req_meta["num_need_pulls"]
 
         # Full prefix cache hit: do not need to read remote blocks, just notify
         # P worker that we have the blocks we need.
@@ -331,23 +346,28 @@ class KVCacheRecvingThread(threading.Thread):
 
         # Check if we have the remote metadata cached.
         if remote_engine_id not in self.kv_caches_base_addr or \
-            remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
+                remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
             self._get_remote_metadata(remote_host, remote_handshake_port)
 
-        grouped_remote_block_ids, grouped_local_block_ids = \
-            group_concurrent_contiguous(remote_block_ids, local_block_ids)
+        if self.num_need_pulls == 1:
+            grouped_remote_block_ids, grouped_local_block_ids = \
+                group_concurrent_contiguous(remote_block_ids, local_block_ids)
+        else:
+            remote_block_ids = list(map(lambda x: [x], remote_block_ids))
+            local_block_ids = list(map(lambda x: [x], local_block_ids))
+            grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids
+        num_transfer_groups = len(grouped_remote_block_ids)
+
         remote_kv_caches_base_addrs = \
             self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
         local_kv_caches_base_addrs = \
            self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port]
-
-        req_start_time = time.perf_counter()
-        num_transfer_groups = len(grouped_remote_block_ids)
-        num_blocks = len(local_block_ids)
-
         remote_transfer_port = self.remote_te_port[remote_engine_id][
             remote_handshake_port]
+        num_blocks = len(local_block_ids)
         session_id = f"{remote_host}:{remote_transfer_port}"
+
+        req_start_time = time.perf_counter()
         src_list, dst_list, length_list = [], [], []
         for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                 zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
@@ -357,14 +377,17 @@ class KVCacheRecvingThread(threading.Thread):
                 block_len = (self.block_len[k % 3])
             else:
                 block_len = (self.block_len[0])
-            for i, remote_block_id in enumerate(grouped_remote_block_ids):
-                local_block_ids = grouped_local_block_ids[i]
-                src = src_layer_base_addr + local_block_ids[0] * block_len
-                dst = dst_layer_base_addr + remote_block_id[0] * block_len
-                length = len(local_block_ids) * block_len
+            inner_block_len = block_len // self.num_need_pulls
+            for remote_block_id, local_block_id in zip(
+                    grouped_remote_block_ids, grouped_local_block_ids):
+                src = src_layer_base_addr + local_block_id[
+                    0] * block_len + offset * inner_block_len
+                dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
+                length = inner_block_len * len(local_block_id)
                 src_list.append(src)
                 dst_list.append(dst)
                 length_list.append(length)
+
         ret = self.engine.batch_transfer_sync_read(session_id, src_list,
                                                    dst_list, length_list)
         if ret < 0:
@@ -376,8 +399,99 @@ class KVCacheRecvingThread(threading.Thread):
         req_transfer_elapsed = (req_end_time - req_start_time) * 1000
         logger.info(
             "KV cache transfer for request %s took %.2f ms (%d groups,"
-            " %d blocks).", request_id, req_transfer_elapsed,
-            num_transfer_groups, num_blocks)
+            " %d blocks). local_ip %s local_device_id %s remote_session_id %s",
+            request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
+            get_ip(), self.tp_rank, session_id)
+        if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1:
+            self._cat_kv_cache(grouped_local_block_ids)
+
+    def _cat_kv_cache(self, block_ids: list[list[int]]):
+        # Get necessary parameters
+        k_cache = list(self.kv_caches.values())[0][0]
+        kv_shape = k_cache.shape
+        dtype = k_cache.dtype
+        device = k_cache.device
+        head_dim = self.model_config.hf_config.head_dim
+        block_size = self.vllm_config.cache_config.block_size
+        num_kv_head = max(
+            self.model_config.hf_config.num_key_value_heads // self.tp_size, 1)
+
+        flat_block_ids = [item for sublist in block_ids for item in sublist]
+        block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
+        num_blocks = len(flat_block_ids)
+        block_len = num_blocks * block_size
+
+        # Create device tensors for copy operations
+        block_table = block_ids_tensor.view(1, -1).to(device=device)
+        block_len_tensor = torch.tensor([block_len],
+                                        dtype=torch.int32).to(device=device)
+        seq_start_tensor = torch.tensor([0],
+                                        dtype=torch.int32).to(device=device)
+
+        # Initialize buffers
+        k_buffer = torch.empty(block_len,
+                               num_kv_head,
+                               head_dim,
+                               dtype=dtype,
+                               device=device)
+        v_buffer = torch.empty(block_len,
+                               num_kv_head,
+                               head_dim,
+                               dtype=dtype,
+                               device=device)
+
+        # Create slot mapping for reshape operations
+        block_offsets = torch.arange(0, block_size, dtype=torch.int32)
+        slot_mapping = (block_offsets.reshape(
+            (1, block_size)) + block_ids_tensor.reshape(
+                (num_blocks, 1)) * block_size)
+        slot_mapping = slot_mapping.flatten().to(device=device)
+
+        # Process each layer in the KV cache
+        for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
+            if len(
+                    k_cache_layer.shape
+            ) == 3:  # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim]
+                k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1],
+                                                   num_kv_head, head_dim)
+                v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1],
+                                                   num_kv_head, head_dim)
+            # Load cache data into buffers
+            torch_npu.atb.npu_paged_cache_load(
+                k_cache_layer,
+                v_cache_layer,
+                block_table,
+                block_len_tensor,
+                seq_starts=seq_start_tensor,
+                key=k_buffer,
+                value=v_buffer,
+            )
+
+            # Transpose KV cache
+            k_buffer = self._transpose_kv_cache_between_head(
+                k_buffer, num_blocks, block_size, block_len, num_kv_head)
+            v_buffer = self._transpose_kv_cache_between_head(
+                v_buffer, num_blocks, block_size, block_len, num_kv_head)
+
+            # Reshape and cache the processed buffers
+            torch_npu._npu_reshape_and_cache(
+                key=k_buffer,
+                value=v_buffer,
+                key_cache=k_cache_layer,
+                value_cache=v_cache_layer,
+                slot_indices=slot_mapping,
+            )
+
+        # Clean up buffers
+        del k_buffer, v_buffer
+
+    def _transpose_kv_cache_between_head(self, buffer: torch.Tensor,
+                                         num_blocks: int, block_size: int,
+                                         block_len: int,
+                                         num_kv_head: int) -> torch.Tensor:
+        buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1)
+        buffer.transpose_(1, 2)
+        return buffer.contiguous().view(block_len, num_kv_head, -1)
 
     def _get_remote_metadata(self, remote_host: str,
                              remote_handshake_port: int) -> None:
@@ -573,9 +687,11 @@ class MooncakeConnectorScheduler:
 
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
+        init_ascend_config(vllm_config)
         self.ascend_config = get_ascend_config()
         self.block_size = vllm_config.cache_config.block_size
         self.engine_id = engine_id
+        self.local_ip = get_ip()
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
         self.side_channel_host = get_ip()
@@ -716,6 +832,7 @@ class MooncakeConnectorScheduler:
             remote_engine_id=self.engine_id,
             remote_host=self.side_channel_host,
             remote_port=self.side_channel_port,
+            last_token_id=request.output_token_ids[-1],
         )
 
     def get_finished_count(self) -> Optional[int]:
@@ -732,12 +849,23 @@ class MooncakeConnectorScheduler:
                 "decode", {})
             assert "tp_size" in decode_parallel_config.keys()
             self._decode_tp_size = decode_parallel_config["tp_size"]
-
+        num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
         if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
-            return self._decode_tp_size
+            num_need_pulls = 1
         else:
-            # TODO support mha and gqa
-            return None
+            num_p_block_heads = max(
+                1, num_key_value_heads // self._prefill_tp_size)
+            num_d_block_heads = max(
+                1, num_key_value_heads // self._decode_tp_size)
+            num_need_pulls = num_d_block_heads // num_p_block_heads
+        kv_role = self.vllm_config.kv_transfer_config.kv_role
+        logger.debug(
+            "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
+            kv_role, num_need_pulls, self._decode_tp_size)
+        if kv_role == 'kv_producer':
+            return num_need_pulls * self._decode_tp_size
+        else:
+            return self._decode_tp_size
 
 
 class MooncakeConnectorWorker:
@@ -757,6 +885,7 @@ class MooncakeConnectorWorker:
 
         # Metadata.
         self.vllm_config = vllm_config
+        self.ascend_config = get_ascend_config()
         self.engine_id = engine_id
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = vllm_config.parallel_config.tensor_parallel_size
@@ -767,6 +896,7 @@ class MooncakeConnectorWorker:
         self.side_channel_host = get_ip()
         self.max_device_id = self.tp_size * self.dp_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role
+        self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
 
         # Handshake base port
         self.side_channel_port = (
@@ -809,8 +939,17 @@ class MooncakeConnectorWorker:
         self.kv_send_thread: Optional[KVCacheSendingThread] = None
         self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
 
+        # kv_transfer variables
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
+        if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+            self.num_need_pulls = 1
+        else:
+            num_d_block_heads = max(1,
+                                    self.num_key_value_heads // self.tp_size)
+            num_p_block_heads = max(
+                1, self.num_key_value_heads // self._prefill_tp_size)
+            self.num_need_pulls = num_d_block_heads // num_p_block_heads
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
@@ -886,15 +1025,17 @@ class MooncakeConnectorWorker:
                 self.num_blocks, block_shape_norm, block_shape_pe,
                 block_shape_k)
         else:
-            # [num_block, block_size, num_head, hidden_dim]
+            # eager:[num_block, block_size, num_head, hidden_dim]
+            # torchair:[num_block, block_size, num_head*hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
             kv_elem_size = first_kv_cache.element_size()
-            block_rank = 3  # [block_size, kv_heads, head_dim]
+            block_rank = len(
+                first_kv_cache.shape
+            ) - 1  # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             self.block_len = [kv_elem_size * math.prod(block_shape)]
             logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                         block_shape)
-
         logger.info(
             "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
             self.use_mla, self.use_sfa, first_kv_cache.shape)
@@ -935,23 +1076,21 @@ class MooncakeConnectorWorker:
 
         ready_event = threading.Event()
         if self.kv_role == 'kv_producer':
-            self.kv_send_thread = KVCacheSendingThread(self.tp_rank,
-                                                       self._decode_tp_size,
-                                                       self.engine_id,
-                                                       self.side_channel_host,
-                                                       self.side_channel_port,
-                                                       metadata, ready_event)
+            self.kv_send_thread = KVCacheSendingThread(
+                self.tp_rank, self._decode_tp_size, self.engine_id,
+                self.side_channel_host, self.side_channel_port, metadata,
+                ready_event, self.kv_caches)
             self.kv_send_thread.start()
         else:
             self.kv_recv_thread = KVCacheRecvingThread(
                 self.tp_rank, self.tp_size, self.engine, self.engine_id,
                 self.handshake_port, kv_caches_base_addr, self.block_len,
-                ready_event)
+                ready_event, self.vllm_config, self.kv_caches)
             self.kv_recv_thread.start()
         ready_event.wait()
 
     def _register(self, ptr, length):
-        logger.info(
+        logger.debug(
             "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
             "block_lens=%s", ptr, length, self.num_blocks, self.block_len)
         ret_value = self.engine.register_memory(ptr, length)
@@ -982,16 +1121,21 @@ class MooncakeConnectorWorker:
                 meta.remote_engine_id, len(meta.local_block_ids),
                 len(meta.remote_block_ids))
-            remote_handshake_port = meta.remote_port + \
-                self._get_remote_tp_rank(req_id)
-            self.kv_recv_thread.add_request(  # type: ignore[union-attr]
-                request_id=req_id,
-                local_block_ids=meta.local_block_ids,
-                remote_block_ids=meta.remote_block_ids,
-                remote_engine_id=meta.remote_engine_id,
-                remote_host=meta.remote_host,
-                remote_handshake_port=remote_handshake_port,
-            )
+            choosen_rank_list = self._get_remote_tp_rank(req_id)
+            remote_handshake_port_list = [
+                x + meta.remote_port for x in choosen_rank_list
+            ]
+            for i in range(self.num_need_pulls):
+                assert self.kv_recv_thread is not None
+                self.kv_recv_thread.add_request(
+                    request_id=req_id,
+                    local_block_ids=meta.local_block_ids,
+                    remote_block_ids=meta.remote_block_ids,
+                    remote_engine_id=meta.remote_engine_id,
+                    remote_host=meta.remote_host,
+                    remote_handshake_port=remote_handshake_port_list[i],
+                    offset=i,
+                    num_need_pulls=self.num_need_pulls)
 
         if self.kv_send_thread is not None:
             for req_id, delay_start_time in metadata.requests_to_send.items():
@@ -999,17 +1143,43 @@ class MooncakeConnectorWorker:
                 self.kv_send_thread.add_delayed_request(
                     req_id, delay_start_time)
 
-    def _get_remote_tp_rank(self, req_id: str) -> int:
+    def _get_remote_tp_rank(self, req_id: str) -> List[int]:
         return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
 
-    def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]:
+    def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
         if self._prefill_tp_size == self._decode_tp_size:
-            return list(range(self._prefill_tp_size))
+            result = list(map(lambda x: [x], range(self._prefill_tp_size)))
+            return result
 
         seed = string_to_int64_hash(req_id)
         rand = random.Random(seed)
-        sampled_nums = rand.sample(range(self._prefill_tp_size),
-                                   self._decode_tp_size)
+        sampled_nums = []
+        ori_data = np.arange(self._prefill_tp_size)
+        # random split prefill tp list
+        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+            # use deepseek mla, num_key_value_heads == 128, but consider as 1
+            if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+                num_kv_head = 1
+            else:
+                num_kv_head = self.num_key_value_heads
+            num_groups = len(ori_data) // num_kv_head
+            ori_data = ori_data.reshape(-1, num_groups)
+            rand_group_index = rand.sample(range(num_groups), \
+                max(self._decode_tp_size // num_kv_head, 1))  # random choose a group
+
+            choosen_group = ori_data[:, [rand_group_index]]
+            flattened = choosen_group.reshape(-1).tolist()
+            sampled_nums = [
+                flattened[i:i + self.num_need_pulls]
+                for i in range(0, len(flattened), self.num_need_pulls)
+            ]
+
+        # non-random split
+        else:
+            group_size = self._prefill_tp_size // self._decode_tp_size
+            for i in range(self._decode_tp_size):
+                slice = ori_data[i * group_size:(i + 1) * group_size]
+                sampled_nums.append(slice)
        return sampled_nums
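
Review note, not part of the patch: the sketch below restates the arithmetic that the hunks in get_finished_count, __init__ of MooncakeConnectorWorker, and _transfer_kv_cache introduce for the case where prefill and decode run with different TP sizes on a GQA/MHA model. The helper names and the example sizes are hypothetical; only the formulas are taken from the diff, and the length returned here is per block (the patch multiplies by the group size, which is 1 in the multi-pull path).

```python
# Illustration only: mirrors the num_need_pulls and src/dst arithmetic added
# in the patch. Function names and the example sizes are made up.

def num_need_pulls(num_kv_heads: int, prefill_tp: int, decode_tp: int,
                   use_mla_or_sfa: bool) -> int:
    """How many prefill ranks one decode rank has to pull KV slices from."""
    if use_mla_or_sfa:
        return 1
    num_p_block_heads = max(1, num_kv_heads // prefill_tp)
    num_d_block_heads = max(1, num_kv_heads // decode_tp)
    return num_d_block_heads // num_p_block_heads


def pull_addresses(local_base: int, remote_base: int, local_block: int,
                   remote_block: int, local_block_len: int, offset: int,
                   pulls: int) -> tuple[int, int, int]:
    """(src, dst, length) for one block of pull number `offset` (0-based)."""
    # Each remote (prefill) block holds 1/pulls of the local block's heads,
    # so the remote stride and the copied length are inner_block_len.
    inner_block_len = local_block_len // pulls
    src = local_base + local_block * local_block_len + offset * inner_block_len
    dst = remote_base + remote_block * inner_block_len
    return src, dst, inner_block_len


# Example: 8 KV heads, prefill TP = 8, decode TP = 2. Each decode rank owns
# 4 heads per block, each prefill rank only 1, so 4 pulls fill one block.
assert num_need_pulls(8, prefill_tp=8, decode_tp=2, use_mla_or_sfa=False) == 4
```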
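A second illustrative sketch with toy shapes that are not from the patch: after the last pull, each local block stores the num_need_pulls head slices back-to-back, one slice per prefill rank. _transpose_kv_cache_between_head reorders that layout so every token's heads are contiguous again; the view/transpose sequence below matches the patch, run on CPU with tiny assumed dimensions.

```python
import torch

# Toy sizes, assumed only for this sketch: 1 block of 4 tokens, 4 KV heads on
# the decode rank, head_dim 2, filled by 2 pulls (2 heads per prefill rank).
num_blocks, block_size, num_kv_head, head_dim, num_need_pulls = 1, 4, 4, 2, 2
block_len = num_blocks * block_size  # total token slots

# Layout right after the pulls: within each block, all tokens' heads from
# pull 0 come first, followed by all tokens' heads from pull 1.
pulled = torch.arange(block_len * num_kv_head * head_dim, dtype=torch.float32)


def transpose_between_head(buffer: torch.Tensor) -> torch.Tensor:
    # Same view/transpose/view sequence as _transpose_kv_cache_between_head.
    buffer = buffer.view(num_blocks, num_need_pulls, block_size, -1)
    buffer = buffer.transpose(1, 2)
    return buffer.contiguous().view(block_len, num_kv_head, -1)


fixed = transpose_between_head(pulled)
print(fixed.shape)  # torch.Size([4, 4, 2]) -> [tokens, kv_heads, head_dim]
```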