diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index 0b55c265..114a224f 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -170,36 +170,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): self.thread._transfer_kv_cache(send_task) self.engine.batch_transfer_sync_write.assert_not_called() - def test_transfer_skips_when_tp_not_sender(self): - - thread = KVCacheSendingLayerThread( - engine=self.engine, - total_layers=2, - ready_event=self.ready_event, - tp_rank=1, - pd_head_ratio=1, - num_head_replica=2, - kv_cache_base_addr=[1000, 2000, 3000, 4000], - use_mla=False, - block_len=[1024], - decode_tp_size=1, - first_kv_cache=self.first_kv_cache, - k_buffer=MagicMock(), - v_buffer=MagicMock(), - resharding_stream=MagicMock(), - callback_func=MagicMock()) - req_meta = self.req_meta_base - send_task = SendTask( - send_request={"req3": req_meta}, - wait_event=MagicMock(), - k_cache=self.key, - v_cache=self.value, - layer_idx=1, - rearrange_block_ids=[], - ) - thread._transfer_kv_cache(send_task) - self.engine.batch_transfer_sync_write.assert_not_called() - @patch( "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous", side_effect=group_concurrent_contiguous) @@ -425,6 +395,7 @@ class MockVllmConfig: self.parallel_config.data_parallel_size = 1 self.parallel_config.data_parallel_rank = 0 self.cache_config.block_size = 16 + self.model_config.hf_config.num_key_value_heads = 1 self.kv_transfer_config.engine_id = "test_engine" self.kv_transfer_config.kv_port = 5000 diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index a32817dc..7673322c 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -108,6 +108,7 @@ class AscendSFAMetadata: # chunked prefill by default if no attn_states passed attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill sfa_cp_context: Optional[SfaCpContext] = None + reshape_cache_event: torch.npu.Event = None M = TypeVar("M", bound=AscendSFAMetadata) @@ -369,6 +370,7 @@ class AscendSFAImpl(MLAAttentionImpl): self.enable_sfa_cp = enable_dsa_cp() self.local_num_heads = self.num_heads self.vllm_config = get_current_vllm_config() + self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer if self.enable_sfa_cp: self.local_num_heads = self.num_heads * self.tp_size self.layer_sharding_kwargs = [] @@ -897,11 +899,15 @@ class AscendSFAImpl(MLAAttentionImpl): k = get_tp_group().all_gather(k, 0) if kv_cache is not None: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view( -1, 1), k.view(-1, k.shape[-1])) # b, s, n, d + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() weights, _ = self.weights_proj(x) weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index d783e04f..e7e3219a 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -162,6 +162,7 @@ class KVCacheSendingLayerThread(threading.Thread): self.kv_caches_base_addr = kv_cache_base_addr self.total_layers = total_layers self.use_mla = use_mla + self.use_sparse = len(block_len) == 3 self.block_len = block_len self._decode_tp_size = decode_tp_size self.resharding_stream = resharding_stream @@ -195,17 +196,6 @@ class KVCacheSendingLayerThread(threading.Thread): src_list: list[str] = [] dst_list: list[str] = [] length_list: list[int] = [] - # not need to send kv cache - if self.tp_rank % self.num_head_replica != 0: - logger.debug( - f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})." - ) - return (src_list, dst_list, length_list) - if self.use_mla and self.tp_rank >= self._decode_tp_size: - logger.debug( - f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})." - ) - return (src_list, dst_list, length_list) layer_idx = send_task.layer_idx remote_block_ids = req_meta.remote_block_ids @@ -214,21 +204,36 @@ class KVCacheSendingLayerThread(threading.Thread): local_block_ids = req_meta.local_block_ids if self.pd_head_ratio == 1: - layer_local_kv_base_addr = [ - local_kv_base_addr[i] - for i in [2 * layer_idx, 2 * layer_idx + 1] - ] - layer_remote_kv_base_addr = [ - remote_kv_base_addrs[i] # type:ignore - for i in [2 * layer_idx, 2 * layer_idx + 1] - ] + if self.use_sparse: + layer_local_kv_base_addr = [ + local_kv_base_addr[i] for i in + [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] + ] + layer_remote_kv_base_addr = [ + remote_kv_base_addrs[i] # type:ignore + for i in + [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] + ] + else: + layer_local_kv_base_addr = [ + local_kv_base_addr[i] + for i in [2 * layer_idx, 2 * layer_idx + 1] + ] + layer_remote_kv_base_addr = [ + remote_kv_base_addrs[i] # type:ignore + for i in [2 * layer_idx, 2 * layer_idx + 1] + ] grouped_remote_block_ids, grouped_local_block_ids = \ group_concurrent_contiguous(remote_block_ids, local_block_ids) for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): - block_len = self.block_len[ - k % 2] if self.use_mla else self.block_len[0] + if self.use_mla: + block_len = (self.block_len[k % 2]) + elif self.use_sparse: + block_len = (self.block_len[k % 3]) + else: + block_len = (self.block_len[0]) for group_remote_block_id, group_local_block_id in zip( grouped_remote_block_ids, grouped_local_block_ids): src = src_layer_base_addr + group_local_block_id[ @@ -931,7 +936,9 @@ class MooncakeLayerwiseConnectorWorker: # TODO(tms): Find a more robust way to detect and handle MLA self.use_mla = first_kv_cache_tuple[0].size( - -1) != first_kv_cache_tuple[1].size(-1) + -1) != first_kv_cache_tuple[1].size(-1) and len( + first_kv_cache_tuple) == 2 + self.use_sparse = len(first_kv_cache_tuple) == 3 if self.use_mla: # MLA case.[num_block, block_size, 1, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -945,6 +952,21 @@ class MooncakeLayerwiseConnectorWorker: logger.info( "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", self.num_blocks, block_shape_norm, block_shape_pe) + elif self.use_sparse: + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe), + first_kv_cache[2].element_size() * math.prod(block_shape_k) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", + self.num_blocks, block_shape_norm, block_shape_pe, + block_shape_k) else: # [num_block, block_size, num_head, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -955,8 +977,9 @@ class MooncakeLayerwiseConnectorWorker: logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) - logger.info("Registering KV_Caches. use_mla: %s, shape %s", - self.use_mla, first_kv_cache.shape) + logger.info( + "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", + self.use_mla, self.use_sparse, first_kv_cache.shape) self.kv_caches = kv_caches kv_caches_base_addr = [] @@ -971,9 +994,17 @@ class MooncakeLayerwiseConnectorWorker: kv_caches_base_addr.append(base_addr) ptrs.append(base_addr) lengths.append(region_len) + elif self.use_sparse: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 3] + kv_caches_base_addr.append(base_addr) + ptrs.append(base_addr) + lengths.append(region_len) else: - cache_list = [cache_or_caches - ] if self.use_mla else cache_or_caches + cache_list = [ + cache_or_caches + ] if self.use_mla or self.use_sparse else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] @@ -1046,56 +1077,72 @@ class MooncakeLayerwiseConnectorWorker: if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( ): # enable decode prefix cache - if self.use_mla: - reshape_cache_event = attn_metadata[ - layer_name].reshape_cache_event + if self.use_mla or self.use_sparse: + num_kv_head = self._decode_tp_size else: - reshape_cache_event = attn_metadata.reshape_cache_event + num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads + num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1 + replica_group_idx = self.tp_rank % num_replica_groups + req_ids = sorted(list(connector_metadata.requests.keys())) + selected_req_ids = [ + req_id for i, req_id in enumerate(req_ids) + if i % num_replica_groups == replica_group_idx + ] + if selected_req_ids: + if self.use_mla or self.use_sparse: + reshape_cache_event = attn_metadata[ + layer_name].reshape_cache_event + else: + reshape_cache_event = attn_metadata.reshape_cache_event - if self.pd_head_ratio != 1: - assert self.resharding_stream is not None - with npu_stream_switch(self.resharding_stream): - reshape_cache_event.wait() - rearrange_block_ids = sorted({ - block_id - for request in connector_metadata.requests.values() - for block_id in request.local_block_ids - }) + if self.pd_head_ratio != 1: + assert self.resharding_stream is not None + with npu_stream_switch(self.resharding_stream): + reshape_cache_event.wait() + rearrange_block_ids = sorted({ + block_id + for req_id in selected_req_ids + for block_id in + connector_metadata.requests[req_id].local_block_ids + }) - keys = kv_layer[0][rearrange_block_ids].clone() - values = kv_layer[1][rearrange_block_ids].clone() - # sort kv caches for each block - keys = keys.view(keys.size(0), self.pd_head_ratio, -1, - *keys.shape[2:]).transpose( - 0, 1).reshape_as(keys) - values = values.view(values.size(0), self.pd_head_ratio, - -1, *values.shape[2:]).transpose( - 0, 1).reshape_as(values) - # reshard kv cache - keys = keys.reshape(-1, *kv_layer[0].shape[2:]) - values = values.reshape(-1, *kv_layer[1].shape[2:]) - (keys, values) = kv_alltoall_and_rearrange( - self.pd_head_ratio, keys, values) - else: - keys = None - values = None - rearrange_block_ids = None + keys = kv_layer[0][rearrange_block_ids].clone() + values = kv_layer[1][rearrange_block_ids].clone() + # sort kv caches for each block + keys = keys.view(keys.size(0), self.pd_head_ratio, -1, + *keys.shape[2:]).transpose( + 0, 1).reshape_as(keys) + values = values.view(values.size(0), + self.pd_head_ratio, -1, + *values.shape[2:]).transpose( + 0, 1).reshape_as(values) + # reshard kv cache + keys = keys.reshape(-1, *kv_layer[0].shape[2:]) + values = values.reshape(-1, *kv_layer[1].shape[2:]) + (keys, values) = kv_alltoall_and_rearrange( + self.pd_head_ratio, keys, values) + else: + keys = None + values = None + rearrange_block_ids = None - assert self.kv_send_layer_thread is not None - assert reshape_cache_event is not None - send_task = SendTask(wait_event=reshape_cache_event, - k_cache=keys, - v_cache=values, - layer_idx=self.current_layer, - rearrange_block_ids=rearrange_block_ids) - for req_id, req_meta in connector_metadata.requests.items(): - req_meta_update = self.update_decoder_info(req_id, req_meta) - logger.debug( - f"Add request {req_id} to kv send layer thread. {req_meta_update=}" - ) - send_task.send_request[req_id] = req_meta_update + assert self.kv_send_layer_thread is not None + assert reshape_cache_event is not None + send_task = SendTask(wait_event=reshape_cache_event, + k_cache=keys, + v_cache=values, + layer_idx=self.current_layer, + rearrange_block_ids=rearrange_block_ids) + for req_id, req_meta in connector_metadata.requests.items(): + if req_id in selected_req_ids: + req_meta_update = self.update_decoder_info( + req_id, req_meta) + logger.debug( + f"Add request {req_id} to kv send layer thread. {req_meta_update=}" + ) + send_task.send_request[req_id] = req_meta_update - self.kv_send_layer_thread.send_queue.put(send_task) + self.kv_send_layer_thread.send_queue.put(send_task) self.current_layer += 1 def _get_remote_socket( @@ -1121,8 +1168,13 @@ class MooncakeLayerwiseConnectorWorker: def update_decoder_info(self, req_id, req_meta): req_meta_update = copy.deepcopy(req_meta) - req_meta_update.remote_port = req_meta_update.remote_port + ( - self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size + if self.use_mla or self.use_sparse: + pd_tp_ratio = self.tp_size // self._decode_tp_size + req_meta_update.remote_port = req_meta_update.remote_port + ( + self.tp_rank // pd_tp_ratio) % self._decode_tp_size + else: + req_meta_update.remote_port = req_meta_update.remote_port + ( + self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \ req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]: try: @@ -1146,14 +1198,16 @@ class MooncakeLayerwiseConnectorWorker: logger.info( f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" ) - session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" - ret = self.engine.batch_transfer_sync_write( - session_id, [self.kv_caches_base_addr[0]], - [agent_meta.kv_caches_base_addr[0]], [128]) - if ret < 0: - logger.error( - f"Mooncake transfer failed to create link to device {session_id}" - ) + if self.pd_head_ratio > 1: + # for tp inequal, pre-create link to prevent alltoall out of memory + session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" + ret = self.engine.batch_transfer_sync_write( + session_id, [self.kv_caches_base_addr[0]], + [agent_meta.kv_caches_base_addr[0]], [128]) + if ret < 0: + logger.error( + f"Mooncake transfer failed to create link to device {session_id}" + ) req_meta_update.remote_te_rpc_port = self.remote_te_port[ req_meta_update.remote_engine_id][req_meta_update.remote_port] req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[